Compare commits

...

110 Commits

Author SHA1 Message Date
Pascal Fischer
4d2c774378 refactor networm map generation 2025-03-13 14:29:59 +01:00
Pascal Fischer
ab2e3fec72 expose resource type consts 2025-03-12 13:49:24 +01:00
Hakan Sariman
47f88f7057 Refactor routeIDLookup methods to use Addr() for resolved IP operations 2025-03-11 19:43:58 +08:00
Hakan Sariman
ee33a6ed7c Refactor RemoveLocalPeerStateRoute to eliminate resourceId parameter 2025-03-11 13:19:30 +08:00
Hakan Sariman
da662cfd08 Add source and destination resource IDs to FlowFields 2025-03-11 13:12:54 +08:00
Hakan Sariman
ed2ee1ee9d Merge branch 'feature/flow' into feat/flow-resid 2025-03-11 13:08:11 +08:00
Viktor Liu
76d73548d6 Fix more conflicts 2025-03-10 18:46:01 +01:00
Viktor Liu
11828a064a Fix conflict 2025-03-10 18:35:32 +01:00
Viktor Liu
0c2a3dd937 Merge branch 'main' into feature/flow 2025-03-10 18:30:45 +01:00
Zoltan Papp
cd9eff5331 Increase the timeout to 50 sec (#3481) 2025-03-10 18:23:47 +01:00
Viktor Liu
47dcf8d68c Fix forwarder IP source/destination (#3463) 2025-03-10 14:55:07 +01:00
Viktor Liu
80ceb80197 [client] Ignore candidates that are part of the the wireguard subnet (#3472) 2025-03-10 13:59:21 +01:00
Bethuel Mmbaga
cc8f6bcaf3 [management] Fix tests circular dependency (#3460)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-03-10 15:54:36 +03:00
Zoltan Papp
636a0e2475 [client] Fix engine restart (#3435)
- Refactor the network monitoring to handle one event and it after return
- In the engine restart cancel the upper layer context and the responsibility of the engine stop will be the upper layer
- Before triggering a restart, the engine checks whether the state is already down. This helps avoid unnecessary delayed network restart events.
2025-03-10 13:32:12 +01:00
Viktor Liu
e66e329bf6 [client] Add option to autostart netbird ui in the Windows installer (#3469) 2025-03-10 13:19:17 +01:00
Zoltan Papp
aaa23beeec [client] Prevent to block channel writing (#3474)
The "runningChan" provides feedback to the UI or any client about whether the service is up and running. If the client exits earlier than when the service successfully starts, then this channel causes a block.

- Added timeout for reading the channel to ensure we don't cause blocks for too long for the caller
- Modified channel writing operations to be non-blocking
2025-03-10 13:17:09 +01:00
Zoltan Papp
6bef474e9e [client] Prevent panic in case of double close call (#3475)
Prevent panic in case of double close call
2025-03-10 13:16:28 +01:00
Maycon Santos
81040ff80a [docs] Update typo (#3477) 2025-03-10 11:52:36 +01:00
Viktor Liu
c73481aee4 [client] Enable windows stderr logs by default (#3476) 2025-03-10 11:30:49 +01:00
Hakan Sariman
92286b2541 Implement routeIDLookup for managing local and remote route IDs 2025-03-10 15:58:45 +08:00
Maycon Santos
d8bcf745b0 update integrations 2025-03-09 19:32:38 +01:00
Maycon Santos
8430139d80 fix missing method 2025-03-09 19:03:57 +01:00
Maycon Santos
a2962b4ce0 sync go.sum 2025-03-09 18:50:20 +01:00
Maycon Santos
16fffdb75b sync changes from #3426 2025-03-09 18:48:48 +01:00
Maycon Santos
036cecbf46 update integrations and go mod 2025-03-09 18:47:05 +01:00
Maycon Santos
3482852bb6 sync proto and sum 2025-03-09 18:02:33 +01:00
Maycon Santos
fd62665b1f Merge branch 'main' into feature/flow
# Conflicts:
#	client/cmd/testutil_test.go
#	client/firewall/iptables/router_linux.go
#	client/firewall/nftables/router_linux.go
#	client/firewall/uspfilter/allow_netbird.go
#	client/firewall/uspfilter/allow_netbird_windows.go
#	client/firewall/uspfilter/uspfilter_test.go
#	client/internal/engine.go
#	client/internal/engine_test.go
#	client/server/server_test.go
#	go.mod
#	go.sum
#	management/client/client_test.go
#	management/cmd/management.go
#	management/proto/management.pb.go
#	management/proto/management.proto
#	management/server/account.go
#	management/server/account_test.go
#	management/server/dns_test.go
#	management/server/http/handler.go
#	management/server/http/testing/testing_tools/tools.go
#	management/server/integrations/port_forwarding/controller.go
#	management/server/management_proto_test.go
#	management/server/management_test.go
#	management/server/nameserver_test.go
#	management/server/peer.go
#	management/server/peer_test.go
#	management/server/route_test.go
2025-03-09 17:42:16 +01:00
Viktor Liu
fc1da94520 [client, management] Add port forwarding (#3275)
Add initial support to ingress ports on the client code.

- new types where added
- new protocol messages and controller
2025-03-09 16:06:43 +01:00
Hakan Sariman
1ffe48f0d4 Add nil check in CheckRoutes to prevent potential panic 2025-03-08 12:54:33 +03:00
Hakan Sariman
a3b8a21385 Refactor CheckRoutes to return resource IDs for matching source and destination addresses 2025-03-08 12:26:53 +03:00
Hakan Sariman
86492b88c4 Refactor route handling to simplify route information and improve state management 2025-03-08 12:25:35 +03:00
Hakan Sariman
d08a629f9e Merge branch 'feature/flow' into feat/flow-resid 2025-03-08 12:18:02 +03:00
Viktor Liu
36da464413 Fix tracer test 2025-03-07 17:19:10 +01:00
Hakan Sariman
268e3404d3 Merge branch 'feature/flow' into feat/flow-resid 2025-03-07 18:52:11 +03:00
Hakan Sariman
54d0591833 Refactor route handling to use RouteWithResourceId for improved state management 2025-03-07 18:43:49 +03:00
Viktor Liu
86370a0e7b Use bytes for flows event id (#3439) 2025-03-07 16:12:47 +01:00
Viktor Liu
cb16d0f45f Align packet tracer behavior with actual code paths (#3424) 2025-03-07 14:03:45 +01:00
Viktor Liu
e8d8bd8f18 Add peer traffic rule IDs to allowed connections in flows (#3442) 2025-03-07 13:56:26 +01:00
Viktor Liu
8b07f21c28 Don't track intercepted packets (#3448) 2025-03-07 13:56:16 +01:00
Viktor Liu
54be772ffd Handle flow updates (#3455) 2025-03-07 13:56:00 +01:00
Viktor Liu
3c3a454e61 Fix merge regression 2025-03-06 16:54:15 +01:00
Viktor Liu
5ff77b3595 Add flow userspace counters (#3438) 2025-03-06 16:52:56 +01:00
Viktor Liu
b180edbe5c Track icmp with id only (#3447) 2025-03-06 14:51:23 +01:00
Hakan Sariman
de3b5c78d7 Fix nil pointer dereference in CheckRoutes method 2025-03-06 14:10:31 +03:00
Hakan Sariman
0b42f40cf6 Refactor route management to include resource IDs in state handling 2025-03-06 13:51:46 +03:00
Viktor Liu
0a042ac36d Fix merge conflict 2025-03-05 19:11:20 +01:00
Hakan Sariman
e7f921d787 [client] add resource id fields to netflow events 2025-03-05 20:35:52 +03:00
Viktor Liu
e9f11fb11b Replace net.IP with netip.Addr (#3425) 2025-03-05 18:28:05 +01:00
hakansa
419ed275fa Handle TCP RST flag to transition connection state to closed (#3432) 2025-03-05 18:25:42 +01:00
Viktor Liu
2d4fcaf186 Fix proto numbering (#3436) 2025-03-04 16:57:25 +01:00
Viktor Liu
acf172b52c Add kernel conntrack counters (#3434) 2025-03-04 16:46:03 +01:00
Viktor Liu
8c81a823fa Add flow ACL IDs (#3421) 2025-03-04 16:43:07 +01:00
Maycon Santos
619c549547 sync port forwarding 2025-03-04 16:29:59 +01:00
Maycon Santos
9a713a0987 Merge branch 'feature/port-forwarding' into feature/flow
# Conflicts:
#	go.mod
#	go.sum
2025-03-04 16:28:57 +01:00
Pascal Fischer
c4945cd565 add cleanup scheduler + metrics 2025-03-04 16:21:52 +01:00
Viktor Liu
1e10c17ecb Fix tcp state (#3431) 2025-03-04 11:19:54 +01:00
Viktor Liu
96d5190436 Add icmp type and code to forwarder flow event (#3413) 2025-02-28 21:04:07 +01:00
Viktor Liu
d19c26df06 Fix log direction (#3412) 2025-02-28 21:03:40 +01:00
Viktor Liu
36e36414d9 Fix forwarder log displaying (#3411) 2025-02-28 20:53:01 +01:00
bcmmbaga
7e69589e05 Update management-integrations
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:49:56 +00:00
bcmmbaga
aa613ab79a Update golang.org/x/crypto/ssh
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 19:27:46 +00:00
Viktor Liu
6ead0ff95e Fix log format 2025-02-28 20:24:23 +01:00
Viktor Liu
0db65a8984 Add routed packet drop flow (#3410) 2025-02-28 20:04:59 +01:00
Pascal Fischer
c138807e95 remove log message 2025-02-28 19:54:50 +01:00
Viktor Liu
637c0c8949 Add icmp type and code (#3409) 2025-02-28 19:16:42 +01:00
Viktor Liu
c72e13d8e6 Add conntrack flows (#3406) 2025-02-28 19:16:29 +01:00
Maycon Santos
f6d7bccfa0 Add flow client with sender/receiver (#3405)
add an initial version of receiver client and flow manager receiver and sender
2025-02-28 17:16:18 +00:00
bcmmbaga
e3ed01cafb go mod tidy
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-28 17:10:44 +00:00
Viktor Liu
fa748a7ec2 Add userspace flow implementation (#3393) 2025-02-28 11:08:35 +01:00
Maycon Santos
cccc615783 update flow proto package generated code 2025-02-28 03:09:09 +00:00
Maycon Santos
2021463ca0 update flow proto package name 2025-02-28 02:51:57 +00:00
Maycon Santos
f48cfd52e9 fix logger stop (#3403)
* fix logger stop

* use context to stop receiver

* update test
2025-02-28 00:28:17 +00:00
Pascal Fischer
6838f53f40 add getPeerByIp store method 2025-02-27 19:01:05 +01:00
Maycon Santos
8276236dfa Add netflow manager (#3398)
* Add netflow manager

* fix linter issues
2025-02-27 12:05:20 +00:00
Viktor Liu
994b923d56 Move proto and rename port and icmp info (#3399) 2025-02-27 12:52:33 +01:00
Viktor Liu
59e2432231 Add event proto fields (#3397) 2025-02-27 12:29:50 +01:00
Pascal Fischer
eee0d123e4 [management] add flow settings and credentials (#3389) 2025-02-27 12:17:07 +01:00
Viktor Liu
e943203ae2 Add event fields (#3390)
Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-02-26 12:06:06 +01:00
Pedro Costa
6a775217cf rename flow proto messages 2025-02-25 16:29:54 +00:00
Maycon Santos
175674749f Add memory flow store (#3386) 2025-02-25 15:23:43 +00:00
Pascal Fischer
1e534cecf6 [management] Add flow proto (#3384) 2025-02-25 13:03:27 +01:00
Pedro Costa
aa3aa8c6a8 [management] flow proto 2025-02-25 11:22:54 +00:00
Pascal Fischer
fbdfe45c25 fix merge conflicts on management 2025-02-25 11:57:25 +01:00
Viktor Liu
81ee172db8 Fix route conflict 2025-02-25 11:44:21 +01:00
Viktor Liu
f8fd65a65f Merge branch 'main' into feature/port-forwarding 2025-02-25 11:37:52 +01:00
Bethuel Mmbaga
62b978c050 [management] Add support for tcp/udp allocations (#3381)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-25 10:11:50 +00:00
Bethuel Mmbaga
4ebf1410c6 [management] Add support to allocate same port for public and internal (#3347)
Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-21 11:16:24 +03:00
Viktor Liu
630edf2480 Remove unused var 2025-02-20 13:24:37 +01:00
Viktor Liu
ea469d28d7 Merge branch 'main' into feature/port-forwarding 2025-02-20 13:24:05 +01:00
Pascal Fischer
597f1d47b8 fix management test suite 2025-02-20 13:08:18 +01:00
Viktor Liu
fcc96417f9 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:45:30 +01:00
Viktor Liu
8755211a60 Merge branch 'main' into feature/port-forwarding 2025-02-20 11:39:06 +01:00
Pascal Fischer
e6d4653b08 [management] add cloud tag to get ingress ports api spec (#3300)
* fix tag for get endpoint

* update labels
2025-02-12 16:11:54 +01:00
Zoltan Papp
eb69f2de78 Fix nil pointer exception when load empty list and try to cast it (#3282) 2025-02-06 10:28:42 +01:00
Viktor Liu
206420c085 [client] Fix grouping of peer ACLs with different port ranges (#3289) 2025-02-06 10:28:42 +01:00
Christian Stewart
88a864c195 [relay] Use new upstream for nhooyr.io/websocket package (#3287)
The nhooyr.io/websocket package was renamed to github.com/coder/websocket when
the project was transferred to "coder" as the new maintainer.

Use the new import path and update go.mod and go.sum accordingly.

Signed-off-by: Christian Stewart <christian@aperture.us>
2025-02-06 10:28:42 +01:00
Pascal Fischer
a789e9e6d8 [management] fix duplication detection (#3286) 2025-02-05 21:42:09 +01:00
Viktor Liu
9930913e4e Merge branch 'main' into feature/port-forwarding 2025-02-05 18:55:59 +01:00
Viktor Liu
48675f579f Merge branch 'main' into feature/port-forwarding 2025-02-05 17:44:01 +01:00
Pascal Fischer
afec455f86 [management] copy port info (#3283) 2025-02-05 17:30:42 +01:00
Pascal Fischer
035c5d9f23 [management merge only unique entries on network map merge (#3277) 2025-02-05 16:50:45 +01:00
Viktor Liu
b2a5b29fb2 Merge branch 'main' into feature/port-forwarding 2025-02-05 10:15:37 +01:00
Bethuel Mmbaga
9ec61206c2 [management] Add support for filtering peers by name and IP (#3279)
* add peers ip and name filters

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add get peers filter

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* fix get account peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Extend GetAccountPeers store to support filtering by name and IP

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* Fix get peers references

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
2025-02-05 00:33:15 +03:00
Zoltan Papp
1b011a2d85 [client] Manage the IP forwarding sysctl setting in global way (#3270)
Add new package ipfwdstate that implements reference counting for IP forwarding
state management. This allows multiple usage to safely request IP forwarding
without interfering with each other.
2025-02-03 12:27:18 +01:00
Pascal Fischer
a85ea1ddb0 [manager] ingress ports manager support (#3268)
* add peers manager

* Extend peers manager to support retrieving all peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>

* add network map calc

* move integrations interface

* update management-integrations

* merge main and fix

* go mod tidy

* [management] port forwarding add peer manager fix network map (#3264)

* [management] fix testing tools (#3265)

* Fix net.IPv4 conversion to []byte

* update test to check ipv4

---------

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: bcmmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltán Papp <zoltan.pmail@gmail.com>
2025-02-03 09:37:37 +01:00
Zoltán Papp
829e40d2aa Fix ingress manager unnecessary creation 2025-02-01 10:58:47 +01:00
Pascal Fischer
6344e34880 [management] renamed ingress port endpoints (#3263) 2025-02-01 00:40:33 +01:00
Pascal Fischer
a76ca8c565 Merge branch 'main' into feature/port-forwarding 2025-01-29 22:28:10 +01:00
Zoltan Papp
26693e4ea8 Feature/port forwarding client ingress (#3242)
Client-side forward handling

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>

---------

Co-authored-by: Viktor Liu <17948409+lixmal@users.noreply.github.com>
2025-01-29 16:04:33 +01:00
Pascal Fischer
f6a71f4193 [management] add openapi specs and generate types for port forwarding proxy (#3236) 2025-01-27 17:47:40 +01:00
212 changed files with 10661 additions and 3342 deletions

View File

@@ -258,7 +258,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres', 'mysql' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -325,8 +325,8 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres', 'mysql' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
- name: Install Go - name: Install Go
@@ -392,7 +392,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres' ] store: [ 'sqlite', 'postgres' ]
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:
@@ -461,7 +461,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
arch: [ '386','amd64' ] arch: [ 'amd64' ]
store: [ 'sqlite', 'postgres'] store: [ 'sqlite', 'postgres']
runs-on: ubuntu-22.04 runs-on: ubuntu-22.04
steps: steps:

View File

@@ -0,0 +1,98 @@
package cmd
import (
"fmt"
"sort"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/proto"
)
var forwardingRulesCmd = &cobra.Command{
Use: "forwarding",
Short: "List forwarding rules",
Long: `Commands to list forwarding rules.`,
}
var forwardingRulesListCmd = &cobra.Command{
Use: "list",
Aliases: []string{"ls"},
Short: "List forwarding rules",
Example: " netbird forwarding list",
Long: "Commands to list forwarding rules.",
RunE: listForwardingRules,
}
func listForwardingRules(cmd *cobra.Command, _ []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer conn.Close()
client := proto.NewDaemonServiceClient(conn)
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
if err != nil {
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
}
if len(resp.GetRules()) == 0 {
cmd.Println("No forwarding rules available.")
return nil
}
printForwardingRules(cmd, resp.GetRules())
return nil
}
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
cmd.Println("Available forwarding rules:")
// Sort rules by translated address
sort.Slice(rules, func(i, j int) bool {
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
}
if rules[i].GetProtocol() != rules[j].GetProtocol() {
return rules[i].GetProtocol() < rules[j].GetProtocol()
}
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
})
var lastIP string
for _, rule := range rules {
dPort := portToString(rule.GetDestinationPort())
tPort := portToString(rule.GetTranslatedPort())
if lastIP != rule.GetTranslatedAddress() {
lastIP = rule.GetTranslatedAddress()
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
}
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
}
}
func getFirstPort(portInfo *proto.PortInfo) int {
switch v := portInfo.PortSelection.(type) {
case *proto.PortInfo_Port:
return int(v.Port)
case *proto.PortInfo_Range_:
return int(v.Range.GetStart())
default:
return 0
}
}
func portToString(translatedPort *proto.PortInfo) string {
switch v := translatedPort.PortSelection.(type) {
case *proto.PortInfo_Port:
return fmt.Sprintf("%d", v.Port)
case *proto.PortInfo_Range_:
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
default:
return "No port specified"
}
}

View File

@@ -145,6 +145,7 @@ func init() {
rootCmd.AddCommand(versionCmd) rootCmd.AddCommand(versionCmd)
rootCmd.AddCommand(sshCmd) rootCmd.AddCommand(sshCmd)
rootCmd.AddCommand(networksCMD) rootCmd.AddCommand(networksCMD)
rootCmd.AddCommand(forwardingRulesCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
@@ -153,6 +154,8 @@ func init() {
networksCMD.AddCommand(routesListCmd) networksCMD.AddCommand(routesListCmd)
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd) networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
debugCmd.AddCommand(debugBundleCmd) debugCmd.AddCommand(debugBundleCmd)
debugCmd.AddCommand(logCmd) debugCmd.AddCommand(logCmd)
logCmd.AddCommand(logLevelCmd) logCmd.AddCommand(logLevelCmd)

View File

@@ -10,6 +10,7 @@ import (
"go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -89,13 +90,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics) accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
// either startup error (permanent backoff err) or nil err (successful engine up) // either startup error (permanent backoff err) or nil err (successful engine up)
// TODO: make after-startup backoff err available // TODO: make after-startup backoff err available
run := make(chan error, 1) run := make(chan struct{}, 1)
clientErr := make(chan error, 1)
go func() { go func() {
if err := client.Run(run); err != nil { if err := client.Run(run); err != nil {
run <- err clientErr <- err
} }
}() }()
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err()) return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
} }
return startCtx.Err() return startCtx.Err()
case err := <-run: case err := <-clientErr:
if err != nil { return fmt.Errorf("startup: %w", err)
if stopErr := client.Stop(); stopErr != nil { case <-run:
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
}
return fmt.Errorf("startup: %w", err)
}
} }
c.connect = client c.connect = client

View File

@@ -10,17 +10,18 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes) fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -15,6 +15,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables" nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
"github.com/netbirdio/netbird/client/firewall/uspfilter" "github.com/netbirdio/netbird/client/firewall/uspfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes) return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes) fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -4,12 +4,13 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -30,10 +30,8 @@ type entry struct {
} }
type aclManager struct { type aclManager struct {
iptablesClient *iptables.IPTables iptablesClient *iptables.IPTables
wgIface iFaceMapper wgIface iFaceMapper
routingFwChainName string
entries aclEntries entries aclEntries
optionalEntries map[string][]entry optionalEntries map[string][]entry
ipsetStore *ipsetStore ipsetStore *ipsetStore
@@ -41,12 +39,10 @@ type aclManager struct {
stateManager *statemanager.Manager stateManager *statemanager.Manager
} }
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) { func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
m := &aclManager{ m := &aclManager{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
wgIface: wgIface, wgIface: wgIface,
routingFwChainName: routingFwChainName,
entries: make(map[string][][]string), entries: make(map[string][][]string),
optionalEntries: make(map[string][]entry), optionalEntries: make(map[string][]entry),
ipsetStore: newIpsetStore(), ipsetStore: newIpsetStore(),
@@ -79,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
} }
func (m *aclManager) AddPeerFiltering( func (m *aclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
@@ -314,9 +311,12 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules}) m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...)) m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
// Inbound is handled by our ACLs, the rest is dropped.
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...)) m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
} }
func (m *aclManager) seedInitialOptionalEntries() { func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -13,7 +13,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -31,7 +31,7 @@ type Manager struct {
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD) m.aclMgr, err = newAclManager(iptablesClient, wgIface)
if err != nil { if err != nil {
return nil, fmt.Errorf("create acl manager: %w", err) return nil, fmt.Errorf("create acl manager: %w", err)
} }
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// //
// Comment will be ignored because some system this feature is not supported // Comment will be ignored because some system this feature is not supported
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
_ string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName) return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
} }
_, err := m.AddPeerFiltering( _, err := m.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
"all", "all",
nil, nil,
nil, nil,
firewall.ActionAccept, firewall.ActionAccept,
"", "",
"",
) )
if err != nil { if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err) return fmt.Errorf("allow netbird interface traffic: %w", err)
@@ -226,6 +227,22 @@ func (m *Manager) DisableRouting() error {
return nil return nil
} }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func getConntrackEstablished() []string { func getConntrackEstablished() []string {
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
} }

View File

@@ -10,15 +10,15 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
IsRange: true, IsRange: true,
Values: []uint16{8043, 8046}, Values: []uint16{8043, 8046},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
for _, r := range rule2 { for _, r := range rule2 {
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
// add second rule // add second rule
ip := net.ParseIP("10.20.0.3") ip := net.ParseIP("10.20.0.3")
port := &fw.Port{Values: []uint16{5353}} port := &fw.Port{Values: []uint16{5353}}
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic") _, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Close(nil) err = manager.Close(nil)
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
port := &fw.Port{ port := &fw.Port{
Values: []uint16{443}, Values: []uint16{443},
} }
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range") rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
for _, r := range rule2 { for _, r := range rule2 {
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set") require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }

View File

@@ -15,7 +15,8 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
@@ -23,22 +24,36 @@ import (
// constants needed to manage and create iptable rules // constants needed to manage and create iptable rules
const ( const (
tableFilter = "filter" tableFilter = "filter"
tableNat = "nat" tableNat = "nat"
tableMangle = "mangle" tableMangle = "mangle"
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING" chainPREROUTING = "PREROUTING"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWD = "NETBIRD-RT-FWD" chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpPre = "jump-pre" jumpManglePre = "jump-mangle-pre"
jumpNat = "jump-nat" jumpNatPre = "jump-nat-pre"
matchSet = "--match-set" jumpNatPost = "jump-nat-post"
matchSet = "--match-set"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
) )
type ruleInfo struct {
chain string
table string
rule []string
}
type routeFilteringRuleParams struct { type routeFilteringRuleParams struct {
Sources []netip.Prefix Sources []netip.Prefix
Destination netip.Prefix Destination netip.Prefix
@@ -62,6 +77,7 @@ type router struct {
legacyManagement bool legacyManagement bool
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -104,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
} }
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -111,7 +129,7 @@ func (r *router) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
@@ -139,9 +157,9 @@ func (r *router) AddRouteFiltering(
var err error var err error
if action == firewall.ActionDrop { if action == firewall.ActionDrop {
// after the established rule // after the established rule
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...) err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
} else { } else {
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...) err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
} }
if err != nil { if err != nil {
@@ -156,12 +174,12 @@ func (r *router) AddRouteFiltering(
} }
func (r *router) DeleteRouteRule(rule firewall.Rule) error { func (r *router) DeleteRouteRule(rule firewall.Rule) error {
ruleKey := rule.GetRuleID() ruleKey := rule.ID()
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
setName := r.findSetNameInRule(rule) setName := r.findSetNameInRule(rule)
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("delete route rule: %v", err) return fmt.Errorf("delete route rule: %v", err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -212,6 +230,10 @@ func (r *router) deleteIpSet(setName string) error {
// AddNatRule inserts an iptables rule pair into the nat chain // AddNatRule inserts an iptables rule pair into the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if r.legacyManagement { if r.legacyManagement {
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination) log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
if err := r.addLegacyRouteRule(pair); err != nil { if err := r.addLegacyRouteRule(pair); err != nil {
@@ -238,6 +260,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains // RemoveNatRule removes an iptables rule pair from forwarding and nat chains
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.removeNatRule(pair); err != nil { if err := r.removeNatRule(pair); err != nil {
return fmt.Errorf("remove nat rule: %w", err) return fmt.Errorf("remove nat rule: %w", err)
} }
@@ -264,7 +290,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
} }
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump} rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
@@ -277,7 +303,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair) ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err) return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
} }
delete(r.rules, ruleKey) delete(r.rules, ruleKey)
@@ -305,7 +331,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) { if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
continue continue
} }
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err)) merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
} else { } else {
delete(r.rules, k) delete(r.rules, k)
@@ -343,9 +369,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
chain string chain string
table string table string
}{ }{
{chainRTFWD, tableFilter}, {chainRTFWDIN, tableFilter},
{chainRTNAT, tableNat}, {chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
@@ -365,16 +393,22 @@ func (r *router) createContainers() error {
chain string chain string
table string table string
}{ }{
{chainRTFWD, tableFilter}, {chainRTFWDIN, tableFilter},
{chainRTFWDOUT, tableFilter},
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat},
} { } {
if err := r.createAndSetupChain(chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
} }
} }
if err := r.insertEstablishedRule(chainRTFWD); err != nil { if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
return fmt.Errorf("insert established rule: %w", err)
}
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
return fmt.Errorf("insert established rule: %w", err) return fmt.Errorf("insert established rule: %w", err)
} }
@@ -415,27 +449,6 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
func (r *router) createAndSetupChain(chain string) error {
table := r.getTableForChain(chain)
if err := r.iptablesClient.NewChain(table, chain); err != nil {
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
}
return nil
}
func (r *router) getTableForChain(chain string) string {
switch chain {
case chainRTNAT:
return tableNat
case chainRTPRE:
return tableMangle
default:
return tableFilter
}
}
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished() establishedRule := getConntrackEstablished()
@@ -454,28 +467,43 @@ func (r *router) addJumpRules() error {
// Jump to NAT chain // Jump to NAT chain
natRule := []string{"-j", chainRTNAT} natRule := []string{"-j", chainRTNAT}
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil { if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
return fmt.Errorf("add nat jump rule: %v", err) return fmt.Errorf("add nat postrouting jump rule: %v", err)
} }
r.rules[jumpNat] = natRule r.rules[jumpNatPost] = natRule
// Jump to prerouting chain // Jump to mangle prerouting chain
preRule := []string{"-j", chainRTPRE} preRule := []string{"-j", chainRTPRE}
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil { if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
return fmt.Errorf("add prerouting jump rule: %v", err) return fmt.Errorf("add mangle prerouting jump rule: %v", err)
} }
r.rules[jumpPre] = preRule r.rules[jumpManglePre] = preRule
// Jump to nat prerouting chain
rdrRule := []string{"-j", chainRTRDR}
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
return fmt.Errorf("add nat prerouting jump rule: %v", err)
}
r.rules[jumpNatPre] = rdrRule
return nil return nil
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNat, jumpPre} { for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
table := tableNat var table, chain string
chain := chainPOSTROUTING switch ruleKey {
if ruleKey == jumpPre { case jumpNatPost:
table = tableNat
chain = chainPOSTROUTING
case jumpManglePre:
table = tableMangle table = tableMangle
chain = chainPREROUTING chain = chainPREROUTING
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
} }
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil { if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
@@ -520,6 +548,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
} }
r.rules[ruleKey] = rule r.rules[ruleKey] = rule
r.updateState()
return nil return nil
} }
@@ -535,6 +565,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
log.Debugf("marking rule %s not found", ruleKey) log.Debugf("marking rule %s not found", ruleKey)
} }
r.updateState()
return nil return nil
} }
@@ -564,6 +595,137 @@ func (r *router) updateState() {
} }
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
toDestination := rule.TranslatedAddress.String()
switch {
case len(rule.TranslatedPort.Values) == 0:
// no translated port, use original port
case len(rule.TranslatedPort.Values) == 1:
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
// need the "/originalport" suffix to avoid dnat port randomization
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
default:
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
proto := strings.ToLower(string(rule.Protocol))
rules := make(map[string]ruleInfo, 3)
// DNAT rule
dnatRule := []string{
"!", "-i", r.wgIface.Name(),
"-p", proto,
"-j", "DNAT",
"--to-destination", toDestination,
}
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
rules[ruleKey+dnatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTRDR,
rule: dnatRule,
}
// SNAT rule
snatRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "MASQUERADE",
}
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+snatSuffix] = ruleInfo{
table: tableNat,
chain: chainRTNAT,
rule: snatRule,
}
// Forward filtering rule, if fwd policy is DROP
forwardRule := []string{
"-o", r.wgIface.Name(),
"-p", proto,
"-d", rule.TranslatedAddress.String(),
"-j", "ACCEPT",
}
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
rules[ruleKey+fwdSuffix] = ruleInfo{
table: tableFilter,
chain: chainRTFWDOUT,
rule: forwardRule,
}
for key, ruleInfo := range rules {
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
log.Errorf("rollback failed: %v", rollbackErr)
}
return nil, fmt.Errorf("add rule %s: %w", key, err)
}
r.rules[key] = ruleInfo.rule
}
r.updateState()
return rule, nil
}
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
var merr *multierror.Error
for key, ruleInfo := range rules {
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
// On rollback error, add to rules map for next cleanup
r.rules[key] = ruleInfo.rule
}
}
if merr != nil {
r.updateState()
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
}
delete(r.rules, ruleKey+dnatSuffix)
}
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
}
delete(r.rules, ruleKey+snatSuffix)
}
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
}
delete(r.rules, ruleKey+fwdSuffix)
}
r.updateState()
return nberrors.FormatErrorOrNil(merr)
}
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string { func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
var rule []string var rule []string

View File

@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
}() }()
// Now 5 rules: // Now 5 rules:
// 1. established rule in forward chain // 1. established rule forward in
// 2. jump rule to NAT chain // 2. estbalished rule forward out
// 3. jump rule to PRE chain // 3. jump rule to POST nat chain
// 4. static outbound masquerade rule // 4. jump rule to PRE mangle chain
// 5. static return masquerade rule // 5. jump rule to PRE nat chain
require.Len(t, manager.rules, 5, "should have created rules map") // 6. static outbound masquerade rule
// 7. static return masquerade rule
require.Len(t, manager.rules, 7, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -328,18 +330,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()] rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
// Log the internal rule // Log the internal rule
t.Logf("Internal rule: %v", rule) t.Logf("Internal rule: %v", rule)
// Check if the rule exists in iptables // Check if the rule exists in iptables
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...) exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
assert.NoError(t, err, "Failed to check rule existence") assert.NoError(t, err, "Failed to check rule existence")
assert.True(t, exists, "Rule not found in iptables") assert.True(t, exists, "Rule not found in iptables")

View File

@@ -12,6 +12,6 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) ID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -4,21 +4,20 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }

View File

@@ -26,8 +26,8 @@ const (
// Each firewall type for different OS can use different type // Each firewall type for different OS can use different type
// of the properties to hold data of the created rule // of the properties to hold data of the created rule
type Rule interface { type Rule interface {
// GetRuleID returns the rule id // ID returns the rule id
GetRuleID() string ID() string
} }
// RuleDirection is the traffic direction which a rule is applied // RuleDirection is the traffic direction which a rule is applied
@@ -65,13 +65,13 @@ type Manager interface {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
AddPeerFiltering( AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto Protocol, proto Protocol,
sPort *Port, sPort *Port,
dPort *Port, dPort *Port,
action Action, action Action,
ipsetName string, ipsetName string,
comment string,
) ([]Rule, error) ) ([]Rule, error)
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -80,7 +80,15 @@ type Manager interface {
// IsServerRouteSupported returns true if the firewall supports server side routing operations // IsServerRouteSupported returns true if the firewall supports server side routing operations
IsServerRouteSupported() bool IsServerRouteSupported() bool
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error) AddRouteFiltering(
id []byte,
sources []netip.Prefix,
destination netip.Prefix,
proto Protocol,
sPort *Port,
dPort *Port,
action Action,
) (Rule, error)
// DeleteRouteRule deletes a routing rule // DeleteRouteRule deletes a routing rule
DeleteRouteRule(rule Rule) error DeleteRouteRule(rule Rule) error
@@ -105,6 +113,12 @@ type Manager interface {
EnableRouting() error EnableRouting() error
DisableRouting() error DisableRouting() error
// AddDNATRule adds a DNAT rule
AddDNATRule(ForwardRule) (Rule, error)
// DeleteDNATRule deletes a DNAT rule
DeleteDNATRule(Rule) error
} }
func GenKey(format string, pair RouterPair) string { func GenKey(format string, pair RouterPair) string {

View File

@@ -0,0 +1,27 @@
package manager
import (
"fmt"
"net/netip"
)
// ForwardRule todo figure out better place to this to avoid circular imports
type ForwardRule struct {
Protocol Protocol
DestinationPort Port
TranslatedAddress netip.Addr
TranslatedPort Port
}
func (r ForwardRule) ID() string {
id := fmt.Sprintf("%s;%s;%s;%s",
r.Protocol,
r.DestinationPort.String(),
r.TranslatedAddress.String(),
r.TranslatedPort.String())
return id
}
func (r ForwardRule) String() string {
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
}

View File

@@ -1,30 +1,12 @@
package manager package manager
import ( import (
"fmt"
"strconv" "strconv"
) )
// Protocol is the protocol of the port
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
// ProtocolUnknown unknown protocol
ProtocolUnknown Protocol = "unknown"
)
// Port of the address for firewall rule // Port of the address for firewall rule
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Port struct { type Port struct {
// IsRange is true Values contains two values, the first is the start port, the second is the end port // IsRange is true Values contains two values, the first is the start port, the second is the end port
IsRange bool IsRange bool
@@ -33,6 +15,25 @@ type Port struct {
Values []uint16 Values []uint16
} }
func NewPort(ports ...int) (*Port, error) {
if len(ports) == 0 {
return nil, fmt.Errorf("no port provided")
}
ports16 := make([]uint16, len(ports))
for i, port := range ports {
if port < 1 || port > 65535 {
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
}
ports16[i] = uint16(port)
}
return &Port{
IsRange: len(ports) > 1,
Values: ports16,
}, nil
}
// String interface implementation // String interface implementation
func (p *Port) String() string { func (p *Port) String() string {
var ports string var ports string

View File

@@ -0,0 +1,19 @@
package manager
// Protocol is the protocol of the port
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
type Protocol string
const (
// ProtocolTCP is the TCP protocol
ProtocolTCP Protocol = "tcp"
// ProtocolUDP is the UDP protocol
ProtocolUDP Protocol = "udp"
// ProtocolICMP is the ICMP protocol
ProtocolICMP Protocol = "icmp"
// ProtocolALL cover all supported protocols
ProtocolALL Protocol = "all"
)

View File

@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *AclManager) AddPeerFiltering( func (m *AclManager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
var ipset *nftables.Set var ipset *nftables.Set
if ipsetName != "" { if ipsetName != "" {
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
} }
newRules := make([]firewall.Rule, 0, 2) newRules := make([]firewall.Rule, 0, 2)
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment) ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err) log.Errorf("failed to delete mangle rule: %v", err)
} }
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
log.Errorf("failed to delete mangle rule: %v", err) log.Errorf("failed to delete mangle rule: %v", err)
} }
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
return m.rConn.Flush() return m.rConn.Flush()
} }
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
return err return err
} }
delete(m.rules, r.GetRuleID()) delete(m.rules, r.ID())
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name) m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) { if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
@@ -256,7 +256,6 @@ func (m *AclManager) addIOFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipset *nftables.Set, ipset *nftables.Set,
comment string,
) (*Rule, error) { ) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok { if r, ok := m.rules[ruleId]; ok {
@@ -338,7 +337,7 @@ func (m *AclManager) addIOFiltering(
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop}) mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
} }
userData := []byte(strings.Join([]string{ruleId, comment}, " ")) userData := []byte(ruleId)
chain := m.chainInputRules chain := m.chainInputRules
nftRule := m.rConn.AddRule(&nftables.Rule{ nftRule := m.rConn.AddRule(&nftables.Rule{

View File

@@ -14,7 +14,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -29,7 +29,7 @@ const (
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface { type iFaceMapper interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
} }
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", ip.String()) return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
} }
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment) return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String()) return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
} }
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
// DeletePeerRule from the firewall by rule definition // DeletePeerRule from the firewall by rule definition
@@ -342,6 +343,22 @@ func (m *Manager) Flush() error {
return m.aclManager.Flush() return m.aclManager.Flush()
} }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.DeleteDNATRule(rule)
}
func (m *Manager) createWorkTable() (*nftables.Table, error) { func (m *Manager) createWorkTable() (*nftables.Table, error) {
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
if err != nil { if err != nil {

View File

@@ -16,15 +16,15 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
var ifaceMock = &iFaceMock{ var ifaceMock = &iFaceMock{
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
// iFaceMapper defines subset methods of interface required for manager // iFaceMapper defines subset methods of interface required for manager
type iFaceMock struct { type iFaceMock struct {
NameFunc func() string NameFunc func() string
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
} }
func (i *iFaceMock) Name() string { func (i *iFaceMock) Name() string {
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
panic("NameFunc is not set") panic("NameFunc is not set")
} }
func (i *iFaceMock) Address() iface.WGAddress { func (i *iFaceMock) Address() wgaddr.Address {
if i.AddressFunc != nil { if i.AddressFunc != nil {
return i.AddressFunc() return i.AddressFunc()
} }
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
testClient := &nftables.Conn{} testClient := &nftables.Conn{}
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "") rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
err = manager.Flush() err = manager.Flush()
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
NameFunc: func() string { NameFunc: func() string {
return "lo" return "lo"
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.96.0.1"), IP: net.ParseIP("100.96.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.96.0.0"), IP: net.ParseIP("100.96.0.0"),
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
if i%100 == 0 { if i%100 == 0 {
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
}) })
ip := net.ParseIP("100.96.0.1") ip := net.ParseIP("100.96.0.1")
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule") _, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
require.NoError(t, err, "failed to add peer filtering rule") require.NoError(t, err, "failed to add peer filtering rule")
_, err = manager.AddRouteFiltering( _, err = manager.AddRouteFiltering(
nil,
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")}, []netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
netip.MustParsePrefix("10.1.0.0/24"), netip.MustParsePrefix("10.1.0.0/24"),
fw.ProtocolTCP, fw.ProtocolTCP,

View File

@@ -14,23 +14,31 @@ import (
"github.com/google/nftables" "github.com/google/nftables"
"github.com/google/nftables/binaryutil" "github.com/google/nftables/binaryutil"
"github.com/google/nftables/expr" "github.com/google/nftables/expr"
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/acl/id" nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
const ( const (
chainNameRoutingFw = "netbird-rt-fwd" tableNat = "nat"
chainNameRoutingNat = "netbird-rt-postrouting" chainNameNatPrerouting = "PREROUTING"
chainNameForward = "FORWARD" chainNameRoutingFw = "netbird-rt-fwd"
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
) )
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
@@ -49,16 +57,18 @@ type router struct {
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set] ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
} }
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
r := &router{ r := &router{
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
chains: make(map[string]*nftables.Chain), chains: make(map[string]*nftables.Chain),
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller // clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear() r.ipsetCounter.Clear()
return r.removeAcceptForwardRules() var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *router) removeNatPreroutingRules() error {
table := &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
}
chain := &nftables.Chain{
Name: chainNameNatPrerouting,
Table: table,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
}
rules, err := r.conn.GetRules(table, chain)
if err != nil {
return fmt.Errorf("get rules from nat table: %w", err)
}
var merr *multierror.Error
// Delete rules that have our UserData suffix
for _, rule := range rules {
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
continue
}
if err := r.conn.DelRule(rule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
return nberrors.FormatErrorOrNil(merr)
} }
func (r *router) loadFilterTable() (*nftables.Table, error) { func (r *router) loadFilterTable() (*nftables.Table, error) {
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeNAT, Type: nftables.ChainTypeNAT,
}) })
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
Name: chainNameRoutingRdr,
Table: r.workTable,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
Type: nftables.ChainTypeNAT,
})
// Chain is created by acl manager // Chain is created by acl manager
// TODO: move creation to a common place // TODO: move creation to a common place
r.chains[chainNamePrerouting] = &nftables.Chain{ r.chains[chainNamePrerouting] = &nftables.Chain{
Name: chainNamePrerouting, Name: chainNamePrerouting,
Table: r.workTable, Table: r.workTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting, Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityMangle, Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
} }
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
@@ -165,6 +228,7 @@ func (r *router) createContainers() error {
// AddRouteFiltering appends a nftables rule to the routing chain // AddRouteFiltering appends a nftables rule to the routing chain
func (r *router) AddRouteFiltering( func (r *router) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -173,7 +237,7 @@ func (r *router) AddRouteFiltering(
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action) ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
if _, ok := r.rules[string(ruleKey)]; ok { if _, ok := r.rules[string(ruleKey)]; ok {
return ruleKey, nil return ruleKey, nil
} }
@@ -281,7 +345,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
ruleKey := rule.GetRuleID() ruleKey := rule.ID()
nftRule, exists := r.rules[ruleKey] nftRule, exists := r.rules[ruleKey]
if !exists { if !exists {
log.Debugf("route rule %s not found", ruleKey) log.Debugf("route rule %s not found", ruleKey)
@@ -410,6 +474,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
// AddNatRule appends a nftables rule pair to the nat chain // AddNatRule appends a nftables rule pair to the nat chain
func (r *router) AddNatRule(pair firewall.RouterPair) error { func (r *router) AddNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return err
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -836,6 +904,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
// RemoveNatRule removes the prerouting mark rule // RemoveNatRule removes the prerouting mark rule
func (r *router) RemoveNatRule(pair firewall.RouterPair) error { func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
if err := r.refreshRulesMap(); err != nil { if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err) return fmt.Errorf(refreshRulesMapError, err)
} }
@@ -896,6 +968,269 @@ func (r *router) refreshRulesMap() error {
return nil return nil
} }
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if err := r.ipFwdState.RequestForwarding(); err != nil {
return nil, err
}
ruleKey := rule.ID()
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
return rule, nil
}
protoNum, err := protoToInt(rule.Protocol)
if err != nil {
return nil, fmt.Errorf("convert protocol to number: %w", err)
}
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
return nil, err
}
r.addDnatMasq(rule, protoNum, ruleKey)
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
// TODO: find chains with drop policies and add rules there
if err := r.conn.Flush(); err != nil {
return nil, fmt.Errorf("flush rules: %w", err)
}
return &rule, nil
}
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
dnatExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
}
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
// shifted translated port is not supported in nftables, so we hand this over to xtables
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
}
}
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
if err != nil {
return err
}
dnatExprs = append(dnatExprs, additionalExprs...)
dnatExprs = append(dnatExprs,
&expr.NAT{
Type: expr.NATTypeDestNAT,
Family: uint32(nftables.TableFamilyIPv4),
RegAddrMin: 1,
RegProtoMin: regProtoMin,
RegProtoMax: regProtoMax,
},
)
dnatRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingRdr],
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
switch {
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
return r.handlePortRange(rule)
case len(rule.TranslatedPort.Values) == 0:
return r.handleAddressOnly(rule)
case len(rule.TranslatedPort.Values) == 1:
return r.handleSinglePort(rule)
default:
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
}
}
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
&expr.Immediate{
Register: 3,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
},
}
return exprs, 2, 3, nil
}
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
return exprs, 0, 0, nil
}
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
exprs := []expr.Any{
&expr.Immediate{
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
&expr.Immediate{
Register: 2,
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
},
}
return exprs, 2, 0, nil
}
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
dnatExprs = append(dnatExprs,
&expr.Counter{},
&expr.Target{
Name: "DNAT",
Rev: 2,
Info: &xt.NatRange2{
NatRange: xt.NatRange{
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
MinIP: rule.TranslatedAddress.AsSlice(),
MaxIP: rule.TranslatedAddress.AsSlice(),
MinPort: rule.TranslatedPort.Values[0],
MaxPort: rule.TranslatedPort.Values[1],
},
BasePort: rule.DestinationPort.Values[0],
},
},
)
dnatRule := &nftables.Rule{
Table: &nftables.Table{
Name: tableNat,
Family: nftables.TableFamilyIPv4,
},
Chain: &nftables.Chain{
Name: chainNameNatPrerouting,
Table: r.filterTable,
Type: nftables.ChainTypeNAT,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityNATDest,
},
Exprs: dnatExprs,
UserData: []byte(ruleKey + dnatSuffix),
}
r.conn.AddRule(dnatRule)
r.rules[ruleKey+dnatSuffix] = dnatRule
return nil
}
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
masqExprs := []expr.Any{
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{protoNum},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 16,
Len: 4,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: rule.TranslatedAddress.AsSlice(),
},
}
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
masqExprs = append(masqExprs, &expr.Masq{})
masqRule := &nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameRoutingNat],
Exprs: masqExprs,
UserData: []byte(ruleKey + snatSuffix),
}
r.conn.AddRule(masqRule)
r.rules[ruleKey+snatSuffix] = masqRule
}
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
log.Errorf("%v", err)
}
ruleKey := rule.ID()
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var merr *multierror.Error
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
if err := r.conn.DelRule(dnatRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
}
}
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
if err := r.conn.DelRule(masqRule); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
}
}
if err := r.conn.Flush(); err != nil {
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
}
if merr == nil {
delete(r.rules, ruleKey+dnatSuffix)
delete(r.rules, ruleKey+snatSuffix)
}
return nberrors.FormatErrorOrNil(merr)
}
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR // generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any { func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
var offset uint32 var offset uint32
@@ -959,15 +1294,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
if port.IsRange && len(port.Values) == 2 { if port.IsRange && len(port.Values) == 2 {
// Handle port range // Handle port range
exprs = append(exprs, exprs = append(exprs,
&expr.Cmp{ &expr.Range{
Op: expr.CmpOpGte, Op: expr.CmpOpEq,
Register: 1, Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[0]), FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
}, ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
&expr.Cmp{
Op: expr.CmpOpLte,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
}, },
) )
} else { } else {

View File

@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action) ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
require.NoError(t, err, "AddRouteFiltering failed") require.NoError(t, err, "AddRouteFiltering failed")
t.Cleanup(func() { t.Cleanup(func() {
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
}) })
// Check if the rule is in the internal map // Check if the rule is in the internal map
rule, ok := r.rules[ruleKey.GetRuleID()] rule, ok := r.rules[ruleKey.ID()]
assert.True(t, ok, "Rule not found in internal map") assert.True(t, ok, "Rule not found in internal map")
t.Log("Internal rule expressions:") t.Log("Internal rule expressions:")
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
var nftRule *nftables.Rule var nftRule *nftables.Rule
for _, rule := range rules { for _, rule := range rules {
if string(rule.UserData) == ruleKey.GetRuleID() { if string(rule.UserData) == ruleKey.ID() {
nftRule = rule nftRule = rule
break break
} }
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 { if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
payloadFound = true payloadFound = true
} }
case *expr.Cmp: case *expr.Range:
if port.IsRange { if port.IsRange && len(port.Values) == 2 {
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte { fromPort := binary.BigEndian.Uint16(ex.FromData)
toPort := binary.BigEndian.Uint16(ex.ToData)
if fromPort == port.Values[0] && toPort == port.Values[1] {
portMatchFound = true portMatchFound = true
} }
} else { }
case *expr.Cmp:
if !port.IsRange {
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 { if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
portValue := binary.BigEndian.Uint16(ex.Data) portValue := binary.BigEndian.Uint16(ex.Data)
for _, p := range port.Values { for _, p := range port.Values {
if uint16(p) == portValue { if p == portValue {
portMatchFound = true portMatchFound = true
break break
} }

View File

@@ -16,6 +16,6 @@ type Rule struct {
} }
// GetRuleID returns the rule id // GetRuleID returns the rule id
func (r *Rule) GetRuleID() string { func (r *Rule) ID() string {
return r.ruleID return r.ruleID
} }

View File

@@ -3,21 +3,20 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/device"
) )
type InterfaceState struct { type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress iface.WGAddress `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
return i.NameStr return i.NameStr
} }
func (i *InterfaceState) Address() device.WGAddress { func (i *InterfaceState) Address() wgaddr.Address {
return i.WGAddress return i.WGAddress
} }

View File

@@ -4,6 +4,7 @@ package uspfilter
import ( import (
"context" "context"
"net/netip"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,8 +17,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.tcpTracker.Close() m.tcpTracker.Close()
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -3,12 +3,14 @@ package uspfilter
import ( import (
"context" "context"
"fmt" "fmt"
"net/netip"
"os/exec" "os/exec"
"syscall" "syscall"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -20,28 +22,31 @@ const (
firewallRuleName = "Netbird" firewallRuleName = "Netbird"
) )
// Close closes the firewall manager // Reset firewall to the default state
func (m *Manager) Close(*statemanager.Manager) error { func (m *Manager) Close(*statemanager.Manager) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
m.outgoingRules = make(map[string]RuleSet) m.outgoingRules = make(map[netip.Addr]RuleSet)
m.incomingRules = make(map[string]RuleSet) m.incomingRules = make(map[netip.Addr]RuleSet)
if m.udpTracker != nil { if m.udpTracker != nil {
m.udpTracker.Close() m.udpTracker.Close()
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
} }
if m.icmpTracker != nil { if m.icmpTracker != nil {
m.icmpTracker.Close() m.icmpTracker.Close()
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
} }
if m.tcpTracker != nil { if m.tcpTracker != nil {
m.tcpTracker.Close() m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
} }
if m.forwarder != nil { if fwder := m.forwarder.Load(); fwder != nil {
m.forwarder.Stop() fwder.Stop()
} }
if m.logger != nil { if m.logger != nil {

View File

@@ -3,14 +3,14 @@ package common
import ( import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// IFaceMapper defines subset methods of interface required for manager // IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface { type IFaceMapper interface {
SetFilter(device.PacketFilter) error SetFilter(device.PacketFilter) error
Address() iface.WGAddress Address() wgaddr.Address
GetWGDevice() *wgdevice.Device GetWGDevice() *wgdevice.Device
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice
} }

View File

@@ -1,20 +1,27 @@
// common.go
package conntrack package conntrack
import ( import (
"net" "fmt"
"sync" "net/netip"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// BaseConnTrack provides common fields and locking for all connection types // BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct { type BaseConnTrack struct {
SourceIP net.IP FlowId uuid.UUID
DestIP net.IP Direction nftypes.Direction
SourcePort uint16 SourceIP netip.Addr
DestPort uint16 DestIP netip.Addr
lastSeen atomic.Int64 // Unix nano for atomic access lastSeen atomic.Int64
PacketsTx atomic.Uint64
PacketsRx atomic.Uint64
BytesTx atomic.Uint64
BytesRx atomic.Uint64
} }
// these small methods will be inlined by the compiler // these small methods will be inlined by the compiler
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano()) b.lastSeen.Store(time.Now().UnixNano())
} }
// UpdateCounters safely updates the packet and byte counters
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
if direction == nftypes.Egress {
b.PacketsTx.Add(1)
b.BytesTx.Add(uint64(bytes))
} else {
b.PacketsRx.Add(1)
b.BytesRx.Add(uint64(bytes))
}
}
// GetLastSeen safely gets the last seen timestamp // GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time { func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load()) return time.Unix(0, b.lastSeen.Load())
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
return time.Since(lastSeen) > timeout return time.Since(lastSeen) > timeout
} }
// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte
// MakeIPAddr creates an IPAddr from net.IP
func MakeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}
// ConnKey uniquely identifies a connection // ConnKey uniquely identifies a connection
type ConnKey struct { type ConnKey struct {
SrcIP IPAddr SrcIP netip.Addr
DstIP IPAddr DstIP netip.Addr
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
} }
// makeConnKey creates a connection key func (c ConnKey) String() string {
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
return ConnKey{
SrcIP: MakeIPAddr(srcIP),
DstIP: MakeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}
// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}
// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
ip := make(net.IP, 16)
return &ip
},
},
}
}
// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return *p.Pool.Get().(*net.IP)
}
// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(&ip)
}
// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
} }

View File

@@ -1,94 +1,67 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"testing" "testing"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func BenchmarkIPOperations(b *testing.B) {
b.Run("MakeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = MakeIPAddr(ip)
}
})
b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := MakeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})
b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})
}
// Memory pressure tests // Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) { func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
} }
} }
}) })
b.Run("UDPHighLoad", func(b *testing.B) { b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
// Generate different IPs // Generate different IPs
srcIPs := make([]net.IP, 100) srcIPs := make([]netip.Addr, 100)
dstIPs := make([]net.IP, 100) dstIPs := make([]netip.Addr, 100)
for i := 0; i < 100; i++ { for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs) srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs) dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
// Simulate some valid inbound packets // Simulate some valid inbound packets
if i%3 == 0 { if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
} }
} }
}) })

View File

@@ -1,13 +1,17 @@
package conntrack package conntrack
import ( import (
"net" "context"
"fmt"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -19,18 +23,20 @@ const (
// ICMPConnKey uniquely identifies an ICMP connection // ICMPConnKey uniquely identifies an ICMP connection
type ICMPConnKey struct { type ICMPConnKey struct {
// Supports both IPv4 and IPv6 SrcIP netip.Addr
SrcIP [16]byte DstIP netip.Addr
DstIP [16]byte ID uint16
Sequence uint16 // ICMP sequence number }
ID uint16 // ICMP identifier
func (i ICMPConnKey) String() string {
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
} }
// ICMPConnTrack represents an ICMP connection state // ICMPConnTrack represents an ICMP connection state
type ICMPConnTrack struct { type ICMPConnTrack struct {
BaseConnTrack BaseConnTrack
Sequence uint16 ICMPType uint8
ID uint16 ICMPCode uint8
} }
// ICMPTracker manages ICMP connection states // ICMPTracker manages ICMP connection states
@@ -39,131 +45,201 @@ type ICMPTracker struct {
connections map[ICMPConnKey]*ICMPConnTrack connections map[ICMPConnKey]*ICMPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{} flowLogger nftypes.FlowLogger
ipPool *PreallocatedIPs
} }
// NewICMPTracker creates a new ICMP connection tracker // NewICMPTracker creates a new ICMP connection tracker
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultICMPTimeout timeout = DefaultICMPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &ICMPTracker{ tracker := &ICMPTracker{
logger: logger, logger: logger,
connections: make(map[ICMPConnKey]*ICMPConnTrack), connections: make(map[ICMPConnKey]*ICMPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(ICMPCleanupInterval), cleanupTicker: time.NewTicker(ICMPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := ICMPConnKey{
key := makeICMPKey(srcIP, dstIP, id, seq) SrcIP: srcIP,
DstIP: dstIP,
t.mutex.Lock() ID: id,
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
},
ID: id,
Sequence: seq,
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New ICMP connection %v", key)
} }
t.mutex.Unlock()
conn.UpdateLastSeen()
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := makeICMPKey(dstIP, srcIP, id, seq)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
return false conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
} }
if conn.timeoutExceeded(t.timeout) { return key, false
return false
}
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.ID == id &&
conn.Sequence == seq
} }
func (t *ICMPTracker) cleanupRoutine() { // TrackOutbound records an outbound ICMP connection
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
}
}
// TrackInbound records an inbound ICMP Echo Request
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
}
// track is the common implementation for tracking both inbound and outbound ICMP connections
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
if exists {
return
}
typ, code := typecode.Type(), typecode.Code()
// non echo requests don't need tracking
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
return
}
conn := &ICMPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
ICMPType: typ,
ICMPCode: code,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
t.sendEvent(nftypes.TypeStart, conn, ruleId)
}
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
return false
}
key := ICMPConnKey{
SrcIP: dstIP,
DstIP: srcIP,
ID: id,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false
}
conn.UpdateLastSeen()
conn.UpdateCounters(nftypes.Ingress, size)
return true
}
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
defer t.tickerCancel()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
} }
func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) cleanup() {
t.mutex.Lock() t.mutex.Lock()
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Debug("Removed ICMP connection %v (timeout)", key) t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *ICMPTracker) Close() { func (t *ICMPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// makeICMPKey creates an ICMP connection key func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { t.flowLogger.StoreEvent(nftypes.EventFields{
return ICMPConnKey{ FlowID: conn.FlowId,
SrcIP: MakeIPAddr(srcIP), Type: typ,
DstIP: MakeIPAddr(dstIP), RuleID: ruleID,
ID: id, Direction: conn.Direction,
Sequence: seq, Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
} SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
ICMPType: conn.ICMPType,
ICMPCode: conn.ICMPCode,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
fields := nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeStart,
RuleID: ruleID,
Direction: direction,
Protocol: nftypes.ICMP,
SourceIP: srcIP,
DestIP: dstIP,
ICMPType: typ,
ICMPCode: code,
}
if direction == nftypes.Ingress {
fields.RxPackets = 1
fields.RxBytes = uint64(size)
} else {
fields.TxPackets = 1
fields.TxBytes = uint64(size)
}
t.flowLogger.StoreEvent(fields)
} }

View File

@@ -1,39 +1,39 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
) )
func BenchmarkICMPTracker(b *testing.B) { func BenchmarkICMPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewICMPTracker(DefaultICMPTimeout, logger) tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
} }
}) })
} }

View File

@@ -3,12 +3,16 @@ package conntrack
// TODO: Send RST packets for invalid/timed-out connections // TODO: Send RST packets for invalid/timed-out connections
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -39,6 +43,35 @@ const (
// TCPState represents the state of a TCP connection // TCPState represents the state of a TCP connection
type TCPState int type TCPState int
func (s TCPState) String() string {
switch s {
case TCPStateNew:
return "New"
case TCPStateSynSent:
return "SYN Sent"
case TCPStateSynReceived:
return "SYN Received"
case TCPStateEstablished:
return "Established"
case TCPStateFinWait1:
return "FIN Wait 1"
case TCPStateFinWait2:
return "FIN Wait 2"
case TCPStateClosing:
return "Closing"
case TCPStateTimeWait:
return "Time Wait"
case TCPStateCloseWait:
return "Close Wait"
case TCPStateLastAck:
return "Last ACK"
case TCPStateClosed:
return "Closed"
default:
return "Unknown"
}
}
const ( const (
TCPStateNew TCPState = iota TCPStateNew TCPState = iota
TCPStateSynSent TCPStateSynSent
@@ -53,19 +86,14 @@ const (
TCPStateClosed TCPStateClosed
) )
// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}
// TCPConnTrack represents a TCP connection state // TCPConnTrack represents a TCP connection state
type TCPConnTrack struct { type TCPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
State TCPState State TCPState
established atomic.Bool established atomic.Bool
tombstone atomic.Bool
sync.RWMutex sync.RWMutex
} }
@@ -79,78 +107,126 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
t.established.Store(state) t.established.Store(state)
} }
// IsTombstone safely checks if the connection is marked for deletion
func (t *TCPConnTrack) IsTombstone() bool {
return t.tombstone.Load()
}
// SetTombstone safely marks the connection for deletion
func (t *TCPConnTrack) SetTombstone() {
t.tombstone.Store(true)
}
// TCPTracker manages TCP connection states // TCPTracker manages TCP connection states
type TCPTracker struct { type TCPTracker struct {
logger *nblog.Logger logger *nblog.Logger
connections map[ConnKey]*TCPConnTrack connections map[ConnKey]*TCPConnTrack
mutex sync.RWMutex mutex sync.RWMutex
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
done chan struct{} tickerCancel context.CancelFunc
timeout time.Duration timeout time.Duration
ipPool *PreallocatedIPs flowLogger nftypes.FlowLogger
} }
// NewTCPTracker creates a new TCP connection tracker // NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
if timeout == 0 {
timeout = DefaultTCPTimeout
}
ctx, cancel := context.WithCancel(context.Background())
tracker := &TCPTracker{ tracker := &TCPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*TCPConnTrack), connections: make(map[ConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval), cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
timeout: timeout, timeout: timeout,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { key := ConnKey{
// Create key before lock SrcIP: srcIP,
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if exists {
conn.Lock()
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
conn.Unlock()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// TrackOutbound records an outbound TCP connection
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
}
}
// TrackInbound processes an inbound TCP packet and updates connection state
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
if exists {
return
}
conn := &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.established.Store(false)
conn.tombstone.Store(false)
t.logger.Trace("New %s TCP connection: %s", direction, key)
t.updateState(key, conn, flags, direction == nftypes.Egress)
t.mutex.Lock() t.mutex.Lock()
conn, exists := t.connections[key] t.connections[key] = conn
if !exists {
// Use preallocated IPs
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &TCPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
State: TCPStateNew,
}
conn.UpdateLastSeen()
conn.established.Store(false)
t.connections[key] = conn
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
}
t.mutex.Unlock() t.mutex.Unlock()
// Lock individual connection for state update t.sendEvent(nftypes.TypeStart, conn, ruleID)
conn.Lock()
t.updateState(conn, flags, true)
conn.Unlock()
conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound TCP packet matches a tracked connection // IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
if !isValidFlagCombination(flags) { key := ConnKey{
return false SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
} }
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
@@ -159,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
return false return false
} }
// Handle RST packets // Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 { if flags&TCPRst != 0 {
conn.Lock() if conn.IsTombstone() {
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock()
return true return true
} }
conn.Lock()
conn.SetTombstone()
conn.State = TCPStateClosed
conn.SetEstablished(false)
conn.Unlock() conn.Unlock()
return false conn.UpdateCounters(nftypes.Ingress, size)
t.logger.Trace("TCP connection reset: %s", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
return true
} }
conn.Lock() conn.Lock()
t.updateState(conn, flags, false) t.updateState(key, conn, flags, false)
conn.UpdateLastSeen()
isEstablished := conn.IsEstablished() isEstablished := conn.IsEstablished()
isValidState := t.isValidStateForFlags(conn.State, flags) isValidState := t.isValidStateForFlags(conn.State, flags)
conn.Unlock() conn.Unlock()
@@ -183,18 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
} }
// updateState updates the TCP connection state based on flags // updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) { func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed conn.UpdateLastSeen()
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetEstablished(false)
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", state := conn.State
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) defer func() {
return if state != conn.State {
} t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
}
}()
switch conn.State { switch state {
case TCPStateNew: case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 { if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent conn.State = TCPStateSynSent
@@ -203,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateSynSent: case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 { if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound { if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished conn.State = TCPStateEstablished
conn.SetEstablished(true) conn.SetEstablished(true)
} else {
// Simultaneous open
conn.State = TCPStateSynReceived
} }
} }
@@ -225,22 +304,32 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateCloseWait conn.State = TCPStateCloseWait
} }
conn.SetEstablished(false) conn.SetEstablished(false)
} else if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait1: case TCPStateFinWait1:
switch { switch {
case flags&TCPFin != 0 && flags&TCPAck != 0: case flags&TCPFin != 0 && flags&TCPAck != 0:
// Simultaneous close - both sides sent FIN
conn.State = TCPStateClosing conn.State = TCPStateClosing
case flags&TCPFin != 0: case flags&TCPFin != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPAck != 0: case flags&TCPAck != 0:
conn.State = TCPStateFinWait2 conn.State = TCPStateFinWait2
case flags&TCPRst != 0:
conn.State = TCPStateClosed
conn.SetTombstone()
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateFinWait2: case TCPStateFinWait2:
if flags&TCPFin != 0 { if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
t.logger.Trace("TCP connection %s completed", key)
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateClosing: case TCPStateClosing:
@@ -248,8 +337,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
conn.State = TCPStateTimeWait conn.State = TCPStateTimeWait
// Keep established = false from previous state // Keep established = false from previous state
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d", t.logger.Trace("TCP connection %s closed (simultaneous)", key)
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
case TCPStateCloseWait: case TCPStateCloseWait:
@@ -260,17 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
case TCPStateLastAck: case TCPStateLastAck:
if flags&TCPAck != 0 { if flags&TCPAck != 0 {
conn.State = TCPStateClosed conn.State = TCPStateClosed
conn.SetTombstone()
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d", // Send close event for gracefully closed connections
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.sendEvent(nftypes.TypeEnd, conn, nil)
t.logger.Trace("TCP connection %s closed gracefully", key)
} }
case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
} }
} }
@@ -315,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
return false return false
} }
func (t *TCPTracker) cleanupRoutine() { func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -331,6 +417,12 @@ func (t *TCPTracker) cleanup() {
defer t.mutex.Unlock() defer t.mutex.Unlock()
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.IsTombstone() {
// Clean up tombstoned connections without sending an event
delete(t.connections, key)
continue
}
var timeout time.Duration var timeout time.Duration
switch { switch {
case conn.State == TCPStateTimeWait: case conn.State == TCPStateTimeWait:
@@ -341,29 +433,26 @@ func (t *TCPTracker) cleanup() {
timeout = TCPHandshakeTimeout timeout = TCPHandshakeTimeout
} }
lastSeen := conn.GetLastSeen() if conn.timeoutExceeded(timeout) {
if time.Since(lastSeen) > timeout {
// Return IPs to pool // Return IPs to pool
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
// event already handled by state change
if conn.State != TCPStateTimeWait {
t.sendEvent(nftypes.TypeEnd, conn, nil)
}
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() { func (t *TCPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
// Clean up all remaining IPs // Clean up all remaining IPs
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
@@ -381,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
return true return true
} }
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.TCP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,7 @@
package conntrack package conntrack
import ( import (
"net" "net/netip"
"testing" "testing"
"time" "time"
@@ -9,11 +9,11 @@ import (
) )
func TestTCPStateMachine(t *testing.T) { func TestTCPStateMachine(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags) isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
require.Equal(t, !tt.wantDrop, isValid, tt.desc) require.Equal(t, !tt.wantDrop, isValid, tt.desc)
}) })
} }
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
t.Helper() t.Helper()
// Send initial SYN // Send initial SYN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
// Receive SYN-ACK // Receive SYN-ACK
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
// Send ACK // Send ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
// Test data transfer // Test data transfer
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
require.True(t, valid, "Data should be allowed after handshake") require.True(t, valid, "Data should be allowed after handshake")
}, },
}, },
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Send FIN // Send FIN
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
// Receive ACK for FIN // Receive ACK for FIN
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "ACK for FIN should be allowed") require.True(t, valid, "ACK for FIN should be allowed")
// Receive FIN from other side // Receive FIN from other side
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "FIN should be allowed") require.True(t, valid, "FIN should be allowed")
// Send final ACK // Send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
}, },
{ {
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Receive RST // Receive RST
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
require.True(t, valid, "RST should be allowed for established connection") require.True(t, valid, "RST should be allowed for established connection")
// Connection is logically dead but we don't enforce blocking subsequent packets // Connection is logically dead but we don't enforce blocking subsequent packets
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort) establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
// Both sides send FIN+ACK // Both sides send FIN+ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
require.True(t, valid, "Simultaneous FIN should be allowed") require.True(t, valid, "Simultaneous FIN should be allowed")
// Both sides send final ACK // Both sides send final ACK
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck) valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
require.True(t, valid, "Final ACKs should be allowed") require.True(t, valid, "Final ACKs should be allowed")
}, },
}, },
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Helper() t.Helper()
tracker = NewTCPTracker(DefaultTCPTimeout, logger) tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
tt.test(t) tt.test(t)
}) })
} }
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
} }
func TestRSTHandling(t *testing.T) { func TestRSTHandling(t *testing.T) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("100.64.0.1") srcIP := netip.MustParseAddr("100.64.0.1")
dstIP := net.ParseIP("100.64.0.2") dstIP := netip.MustParseAddr("100.64.0.2")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(80) dstPort := uint16(80)
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
name: "RST in established", name: "RST in established",
setupState: func() { setupState: func() {
// Establish connection first // Establish connection first
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
}, },
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: true, wantValid: true,
desc: "Should accept RST for established connection", desc: "Should accept RST for established connection",
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
name: "RST without connection", name: "RST without connection",
setupState: func() {}, setupState: func() {},
sendRST: func() { sendRST: func() {
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst) tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
}, },
wantValid: false, wantValid: false,
desc: "Should reject RST without connection", desc: "Should reject RST without connection",
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
tt.sendRST() tt.sendRST()
// Verify connection state is as expected // Verify connection state is as expected
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn := tracker.connections[key] conn := tracker.connections[key]
if tt.wantValid { if tt.wantValid {
require.NotNil(t, conn) require.NotNil(t, conn)
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
} }
// Helper to establish a TCP connection // Helper to establish a TCP connection
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) { func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
t.Helper() t.Helper()
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck) valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
require.True(t, valid, "SYN-ACK should be allowed") require.True(t, valid, "SYN-ACK should be allowed")
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
} }
func BenchmarkTCPTracker(b *testing.B) { func BenchmarkTCPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
} }
}) })
b.Run("ConcurrentAccess", func(b *testing.B) { b.Run("ConcurrentAccess", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout, logger) tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
i := 0 i := 0
for pb.Next() { for pb.Next() {
if i%2 == 0 { if i%2 == 0 {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
} else { } else {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
} }
i++ i++
} }
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
// Benchmark connection cleanup // Benchmark connection cleanup
func BenchmarkCleanup(b *testing.B) { func BenchmarkCleanup(b *testing.B) {
b.Run("TCPCleanup", func(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) {
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
defer tracker.Close() defer tracker.Close()
// Pre-populate with expired connections // Pre-populate with expired connections
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
} }
// Wait for connections to expire // Wait for connections to expire

View File

@@ -1,11 +1,15 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"sync" "sync"
"time" "time"
"github.com/google/uuid"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -18,6 +22,8 @@ const (
// UDPConnTrack represents a UDP connection state // UDPConnTrack represents a UDP connection state
type UDPConnTrack struct { type UDPConnTrack struct {
BaseConnTrack BaseConnTrack
SourcePort uint16
DestPort uint16
} }
// UDPTracker manages UDP connection states // UDPTracker manages UDP connection states
@@ -26,89 +32,125 @@ type UDPTracker struct {
connections map[ConnKey]*UDPConnTrack connections map[ConnKey]*UDPConnTrack
timeout time.Duration timeout time.Duration
cleanupTicker *time.Ticker cleanupTicker *time.Ticker
tickerCancel context.CancelFunc
mutex sync.RWMutex mutex sync.RWMutex
done chan struct{} flowLogger nftypes.FlowLogger
ipPool *PreallocatedIPs
} }
// NewUDPTracker creates a new UDP connection tracker // NewUDPTracker creates a new UDP connection tracker
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
if timeout == 0 { if timeout == 0 {
timeout = DefaultUDPTimeout timeout = DefaultUDPTimeout
} }
ctx, cancel := context.WithCancel(context.Background())
tracker := &UDPTracker{ tracker := &UDPTracker{
logger: logger, logger: logger,
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(UDPCleanupInterval), cleanupTicker: time.NewTicker(UDPCleanupInterval),
done: make(chan struct{}), tickerCancel: cancel,
ipPool: NewPreallocatedIPs(), flowLogger: flowLogger,
} }
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
return tracker return tracker
} }
// TrackOutbound records an outbound UDP connection // TrackOutbound records an outbound UDP connection
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
// if (inverted direction) conn is not tracked, track this direction
t.mutex.Lock() t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
conn, exists := t.connections[key]
if !exists {
srcIPCopy := t.ipPool.Get()
dstIPCopy := t.ipPool.Get()
copyIP(srcIPCopy, srcIP)
copyIP(dstIPCopy, dstIP)
conn = &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
SourceIP: srcIPCopy,
DestIP: dstIPCopy,
SourcePort: srcPort,
DestPort: dstPort,
},
}
conn.UpdateLastSeen()
t.connections[key] = conn
t.logger.Trace("New UDP connection: %v", conn)
} }
t.mutex.Unlock()
conn.UpdateLastSeen()
} }
// IsValidInbound checks if an inbound packet matches a tracked connection // TrackInbound records an inbound UDP connection
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
key := makeConnKey(dstIP, srcIP, dstPort, srcPort) t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
}
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
t.mutex.RLock() t.mutex.RLock()
conn, exists := t.connections[key] conn, exists := t.connections[key]
t.mutex.RUnlock() t.mutex.RUnlock()
if !exists { if exists {
conn.UpdateLastSeen()
conn.UpdateCounters(direction, size)
return key, true
}
return key, false
}
// track is the common implementation for tracking both inbound and outbound connections
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
if exists {
return
}
conn := &UDPConnTrack{
BaseConnTrack: BaseConnTrack{
FlowId: uuid.New(),
Direction: direction,
SourceIP: srcIP,
DestIP: dstIP,
},
SourcePort: srcPort,
DestPort: dstPort,
}
conn.UpdateLastSeen()
t.mutex.Lock()
t.connections[key] = conn
t.mutex.Unlock()
t.logger.Trace("New %s UDP connection: %s", direction, key)
t.sendEvent(nftypes.TypeStart, conn, ruleID)
}
// IsValidInbound checks if an inbound packet matches a tracked connection
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
key := ConnKey{
SrcIP: dstIP,
DstIP: srcIP,
SrcPort: dstPort,
DstPort: srcPort,
}
t.mutex.RLock()
conn, exists := t.connections[key]
t.mutex.RUnlock()
if !exists || conn.timeoutExceeded(t.timeout) {
return false return false
} }
if conn.timeoutExceeded(t.timeout) { conn.UpdateLastSeen()
return false conn.UpdateCounters(nftypes.Ingress, size)
}
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && return true
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
conn.DestPort == srcPort &&
conn.SourcePort == dstPort
} }
// cleanupRoutine periodically removes stale connections // cleanupRoutine periodically removes stale connections
func (t *UDPTracker) cleanupRoutine() { func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
defer t.cleanupTicker.Stop()
for { for {
select { select {
case <-t.cleanupTicker.C: case <-t.cleanupTicker.C:
t.cleanup() t.cleanup()
case <-t.done: case <-ctx.Done():
return return
} }
} }
@@ -120,44 +162,58 @@ func (t *UDPTracker) cleanup() {
for key, conn := range t.connections { for key, conn := range t.connections {
if conn.timeoutExceeded(t.timeout) { if conn.timeoutExceeded(t.timeout) {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
delete(t.connections, key) delete(t.connections, key)
t.logger.Trace("Removed UDP connection %v (timeout)", conn) t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
t.sendEvent(nftypes.TypeEnd, conn, nil)
} }
} }
} }
// Close stops the cleanup routine and releases resources // Close stops the cleanup routine and releases resources
func (t *UDPTracker) Close() { func (t *UDPTracker) Close() {
t.cleanupTicker.Stop() t.tickerCancel()
close(t.done)
t.mutex.Lock() t.mutex.Lock()
for _, conn := range t.connections {
t.ipPool.Put(conn.SourceIP)
t.ipPool.Put(conn.DestIP)
}
t.connections = nil t.connections = nil
t.mutex.Unlock() t.mutex.Unlock()
} }
// GetConnection safely retrieves a connection state // GetConnection safely retrieves a connection state
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) { func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
t.mutex.RLock() t.mutex.RLock()
defer t.mutex.RUnlock() defer t.mutex.RUnlock()
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
conn, exists := t.connections[key] SrcIP: srcIP,
if !exists { DstIP: dstIP,
return nil, false SrcPort: srcPort,
DstPort: dstPort,
} }
conn, exists := t.connections[key]
return conn, true return conn, exists
} }
// Timeout returns the configured timeout duration for the tracker // Timeout returns the configured timeout duration for the tracker
func (t *UDPTracker) Timeout() time.Duration { func (t *UDPTracker) Timeout() time.Duration {
return t.timeout return t.timeout
} }
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
t.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: conn.FlowId,
Type: typ,
RuleID: ruleID,
Direction: conn.Direction,
Protocol: nftypes.UDP,
SourceIP: conn.SourceIP,
DestIP: conn.DestIP,
SourcePort: conn.SourcePort,
DestPort: conn.DestPort,
RxPackets: conn.PacketsRx.Load(),
TxPackets: conn.PacketsTx.Load(),
RxBytes: conn.BytesRx.Load(),
TxBytes: conn.BytesTx.Load(),
})
}

View File

@@ -1,7 +1,8 @@
package conntrack package conntrack
import ( import (
"net" "context"
"net/netip"
"testing" "testing"
"time" "time"
@@ -29,54 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tracker := NewUDPTracker(tt.timeout, logger) tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
assert.NotNil(t, tracker) assert.NotNil(t, tracker)
assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.Equal(t, tt.wantTimeout, tracker.timeout)
assert.NotNil(t, tracker.connections) assert.NotNil(t, tracker.connections)
assert.NotNil(t, tracker.cleanupTicker) assert.NotNil(t, tracker.cleanupTicker)
assert.NotNil(t, tracker.done) assert.NotNil(t, tracker.tickerCancel)
}) })
} }
} }
func TestUDPTracker_TrackOutbound(t *testing.T) { func TestUDPTracker_TrackOutbound(t *testing.T) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
// Verify connection was tracked // Verify connection was tracked
key := makeConnKey(srcIP, dstIP, srcPort, dstPort) key := ConnKey{
SrcIP: srcIP,
DstIP: dstIP,
SrcPort: srcPort,
DstPort: dstPort,
}
conn, exists := tracker.connections[key] conn, exists := tracker.connections[key]
require.True(t, exists) require.True(t, exists)
assert.True(t, conn.SourceIP.Equal(srcIP)) assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
assert.True(t, conn.DestIP.Equal(dstIP)) assert.True(t, conn.DestIP.Compare(dstIP) == 0)
assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, srcPort, conn.SourcePort)
assert.Equal(t, dstPort, conn.DestPort) assert.Equal(t, dstPort, conn.DestPort)
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
} }
func TestUDPTracker_IsValidInbound(t *testing.T) { func TestUDPTracker_IsValidInbound(t *testing.T) {
tracker := NewUDPTracker(1*time.Second, logger) tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.2") srcIP := netip.MustParseAddr("192.168.1.2")
dstIP := net.ParseIP("192.168.1.3") dstIP := netip.MustParseAddr("192.168.1.3")
srcPort := uint16(12345) srcPort := uint16(12345)
dstPort := uint16(53) dstPort := uint16(53)
// Track outbound connection // Track outbound connection
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort) tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
tests := []struct { tests := []struct {
name string name string
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
sleep time.Duration sleep time.Duration
@@ -93,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
}, },
{ {
name: "invalid source IP", name: "invalid source IP",
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: srcIP, dstIP: srcIP,
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
@@ -103,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
{ {
name: "invalid destination IP", name: "invalid destination IP",
srcIP: dstIP, srcIP: dstIP,
dstIP: net.ParseIP("192.168.1.4"), dstIP: netip.MustParseAddr("192.168.1.4"),
srcPort: dstPort, srcPort: dstPort,
dstPort: srcPort, dstPort: srcPort,
sleep: 0, sleep: 0,
@@ -143,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
if tt.sleep > 0 { if tt.sleep > 0 {
time.Sleep(tt.sleep) time.Sleep(tt.sleep)
} }
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort) got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
}) })
} }
@@ -154,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
timeout := 50 * time.Millisecond timeout := 50 * time.Millisecond
cleanupInterval := 25 * time.Millisecond cleanupInterval := 25 * time.Millisecond
ctx, tickerCancel := context.WithCancel(context.Background())
defer tickerCancel()
// Create tracker with custom cleanup interval // Create tracker with custom cleanup interval
tracker := &UDPTracker{ tracker := &UDPTracker{
connections: make(map[ConnKey]*UDPConnTrack), connections: make(map[ConnKey]*UDPConnTrack),
timeout: timeout, timeout: timeout,
cleanupTicker: time.NewTicker(cleanupInterval), cleanupTicker: time.NewTicker(cleanupInterval),
done: make(chan struct{}), tickerCancel: tickerCancel,
ipPool: NewPreallocatedIPs(),
logger: logger, logger: logger,
flowLogger: flowLogger,
} }
// Start cleanup routine // Start cleanup routine
go tracker.cleanupRoutine() go tracker.cleanupRoutine(ctx)
// Add some connections // Add some connections
connections := []struct { connections := []struct {
srcIP net.IP srcIP netip.Addr
dstIP net.IP dstIP netip.Addr
srcPort uint16 srcPort uint16
dstPort uint16 dstPort uint16
}{ }{
{ {
srcIP: net.ParseIP("192.168.1.2"), srcIP: netip.MustParseAddr("192.168.1.2"),
dstIP: net.ParseIP("192.168.1.3"), dstIP: netip.MustParseAddr("192.168.1.3"),
srcPort: 12345, srcPort: 12345,
dstPort: 53, dstPort: 53,
}, },
{ {
srcIP: net.ParseIP("192.168.1.4"), srcIP: netip.MustParseAddr("192.168.1.4"),
dstIP: net.ParseIP("192.168.1.5"), dstIP: netip.MustParseAddr("192.168.1.5"),
srcPort: 12346, srcPort: 12346,
dstPort: 53, dstPort: 53,
}, },
} }
for _, conn := range connections { for _, conn := range connections {
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort) tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
} }
// Verify initial connections // Verify initial connections
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
func BenchmarkUDPTracker(b *testing.B) { func BenchmarkUDPTracker(b *testing.B) {
b.Run("TrackOutbound", func(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
} }
}) })
b.Run("IsValidInbound", func(b *testing.B) { b.Run("IsValidInbound", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout, logger) tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
defer tracker.Close() defer tracker.Close()
srcIP := net.ParseIP("192.168.1.1") srcIP := netip.MustParseAddr("192.168.1.1")
dstIP := net.ParseIP("192.168.1.2") dstIP := netip.MustParseAddr("192.168.1.2")
// Pre-populate some connections // Pre-populate some connections
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
} }
}) })
} }

View File

@@ -1,6 +1,8 @@
package forwarder package forwarder
import ( import (
"fmt"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
return true return true
} }
type epID stack.TransportEndpointID
func (i epID) String() string {
// src and remote is swapped
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -29,6 +30,7 @@ const (
type Forwarder struct { type Forwarder struct {
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
stack *stack.Stack stack *stack.Stack
endpoint *endpoint endpoint *endpoint
udpForwarder *udpForwarder udpForwarder *udpForwarder
@@ -38,7 +40,7 @@ type Forwarder struct {
netstack bool netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &Forwarder{ f := &Forwarder{
logger: logger, logger: logger,
flowLogger: flowLogger,
stack: s, stack: s,
endpoint: endpoint, endpoint: endpoint,
udpForwarder: newUDPForwarder(mtu, logger), udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
netstack: netstack, netstack: netstack,

View File

@@ -3,14 +3,30 @@ package forwarder
import ( import (
"context" "context"
"net" "net"
"net/netip"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleICMP handles ICMP packets from the network stack // handleICMP handles ICMP packets from the network stack
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
icmpType := uint8(icmpHdr.Type())
icmpCode := uint8(icmpHdr.Code())
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
// dont process our own replies
return true
}
flowID := uuid.New()
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
defer cancel() defer cancel()
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// TODO: support non-root // TODO: support non-root
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
if err != nil { if err != nil {
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
// This will make netstack reply on behalf of the original destination, that's ok for now // This will make netstack reply on behalf of the original destination, that's ok for now
return false return false
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
dstIP := f.determineDialAddr(id.LocalAddress) dstIP := f.determineDialAddr(id.LocalAddress)
dst := &net.IPAddr{IP: dstIP} dst := &net.IPAddr{IP: dstIP}
// Get the complete ICMP message (header + data)
fullPacket := stack.PayloadSince(pkt.TransportHeader()) fullPacket := stack.PayloadSince(pkt.TransportHeader())
payload := fullPacket.AsSlice() payload := fullPacket.AsSlice()
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) if _, err = conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
epID(id), icmpHdr.Type(), icmpHdr.Code())
// For Echo Requests, send and handle response // For Echo Requests, send and handle response
switch icmpHdr.Type() { if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
case header.ICMPv4Echo: f.handleEchoResponse(icmpHdr, conn, id)
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
default:
} }
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) // For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
return true return true
} }
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool { func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err) f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return true return
} }
response := make([]byte, f.endpoint.mtu) response := make([]byte, f.endpoint.mtu)
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if !isTimeout(err) { if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err) f.logger.Error("Failed to read ICMP response: %v", err)
} }
return true return
} }
ipHdr := make([]byte, header.IPv4MinimumSize) ipHdr := make([]byte, header.IPv4MinimumSize)
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
if err := f.InjectIncomingPacket(fullPacket); err != nil { if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err) f.logger.Error("Failed to inject ICMP response: %v", err)
return true
return
} }
f.logger.Trace("Forwarded ICMP echo reply for %v", id) f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
return true epID(id), icmpHdr.Type(), icmpHdr.Code())
}
// sendICMPEvent stores flow events for ICMP packets
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
f.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.ICMP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
ICMPType: icmpType,
ICMPCode: icmpCode,
// TODO: get packets/bytes
})
} }

View File

@@ -5,24 +5,38 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
// handleTCP is called by the TCP forwarder for new connections. // handleTCP is called by the TCP forwarder for new connections.
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID() id := r.ID()
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if err != nil { if err != nil {
r.Complete(true) r.Complete(true)
f.logger.Trace("forwarder: dial error for %v: %v", id, err) f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
return return
} }
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
inConn := gonet.NewTCPConn(&wq, ep) inConn := gonet.NewTCPConn(&wq, ep)
f.logger.Trace("forwarder: established TCP connection %v", id) success = true
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
go f.proxyTCP(id, inConn, outConn, ep) go f.proxyTCP(id, inConn, outConn, ep, flowID)
} }
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
defer func() { defer func() {
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: inConn close error: %v", err) f.logger.Debug("forwarder: inConn close error: %v", err)
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.logger.Debug("forwarder: outConn close error: %v", err) f.logger.Debug("forwarder: outConn close error: %v", err)
} }
ep.Close() ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
}() }()
// Create context for managing the proxy goroutines // Create context for managing the proxy goroutines
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
select { select {
case <-ctx.Done(): case <-ctx.Done():
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id) f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
f.logger.Error("proxyTCP: copy error: %v", err) f.logger.Error("proxyTCP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down TCP connection %v", id) f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
return return
} }
} }
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.TCP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.SegmentsSent.Value()
fields.TxPackets = tcpStats.SegmentsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}

View File

@@ -5,10 +5,12 @@ import (
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/google/uuid"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -16,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/waiter" "gvisor.dev/gvisor/pkg/waiter"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
) )
const ( const (
@@ -28,15 +31,17 @@ type udpPacketConn struct {
lastSeen atomic.Int64 lastSeen atomic.Int64
cancel context.CancelFunc cancel context.CancelFunc
ep tcpip.Endpoint ep tcpip.Endpoint
flowID uuid.UUID
} }
type udpForwarder struct { type udpForwarder struct {
sync.RWMutex sync.RWMutex
logger *nblog.Logger logger *nblog.Logger
conns map[stack.TransportEndpointID]*udpPacketConn flowLogger nftypes.FlowLogger
bufPool sync.Pool conns map[stack.TransportEndpointID]*udpPacketConn
ctx context.Context bufPool sync.Pool
cancel context.CancelFunc ctx context.Context
cancel context.CancelFunc
} }
type idleConn struct { type idleConn struct {
@@ -44,13 +49,14 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,
conns: make(map[stack.TransportEndpointID]*udpPacketConn), flowLogger: flowLogger,
ctx: ctx, conns: make(map[stack.TransportEndpointID]*udpPacketConn),
cancel: cancel, ctx: ctx,
cancel: cancel,
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() any { New: func() any {
b := make([]byte, mtu) b := make([]byte, mtu)
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
for id, conn := range f.conns { for id, conn := range f.conns {
conn.cancel() conn.cancel()
if err := conn.conn.Close(); err != nil { if err := conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
} }
if err := conn.outConn.Close(); err != nil { if err := conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
conn.ep.Close() conn.ep.Close()
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
for _, idle := range idleConns { for _, idle := range idleConns {
idle.conn.cancel() idle.conn.cancel()
if err := idle.conn.conn.Close(); err != nil { if err := idle.conn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
} }
if err := idle.conn.outConn.Close(); err != nil { if err := idle.conn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
} }
idle.conn.ep.Close() idle.conn.ep.Close()
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
delete(f.conns, idle.id) delete(f.conns, idle.id)
f.Unlock() f.Unlock()
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
} }
} }
} }
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
_, exists := f.udpForwarder.conns[id] _, exists := f.udpForwarder.conns[id]
f.udpForwarder.RUnlock() f.udpForwarder.RUnlock()
if exists { if exists {
f.logger.Trace("forwarder: existing UDP connection for %v", id) f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
return return
} }
flowID := uuid.New()
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
var success bool
defer func() {
if !success {
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
}
}()
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
if err != nil { if err != nil {
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
// TODO: Send ICMP error message // TODO: Send ICMP error message
return return
} }
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
if epErr != nil { if epErr != nil {
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
outConn: outConn, outConn: outConn,
cancel: connCancel, cancel: connCancel,
ep: ep, ep: ep,
flowID: flowID,
} }
pConn.updateLastSeen() pConn.updateLastSeen()
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
pConn.cancel() pConn.cancel()
if err := inConn.Close(); err != nil { if err := inConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := outConn.Close(); err != nil { if err := outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
return return
} }
f.udpForwarder.conns[id] = pConn f.udpForwarder.conns[id] = pConn
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.logger.Trace("forwarder: established UDP connection to %v", id) success = true
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
go f.proxyUDP(connCtx, pConn, id, ep) go f.proxyUDP(connCtx, pConn, id, ep)
} }
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
defer func() { defer func() {
pConn.cancel() pConn.cancel()
if err := pConn.conn.Close(); err != nil { if err := pConn.conn.Close(); err != nil {
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
} }
if err := pConn.outConn.Close(); err != nil { if err := pConn.outConn.Close(); err != nil {
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
} }
ep.Close() ep.Close()
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
f.udpForwarder.Lock() f.udpForwarder.Lock()
delete(f.udpForwarder.conns, id) delete(f.udpForwarder.conns, id)
f.udpForwarder.Unlock() f.udpForwarder.Unlock()
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
}() }()
errChan := make(chan error, 2) errChan := make(chan error, 2)
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
select { select {
case <-ctx.Done(): case <-ctx.Done():
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id) f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
return return
case err := <-errChan: case err := <-errChan:
if err != nil && !isClosedError(err) { if err != nil && !isClosedError(err) {
f.logger.Error("proxyUDP: copy error: %v", err) f.logger.Error("proxyUDP: copy error: %v", err)
} }
f.logger.Trace("forwarder: tearing down UDP connection %v", id) f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
return return
} }
} }
// sendUDPEvent stores flow events for UDP connections
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
fields := nftypes.EventFields{
FlowID: flowID,
Type: typ,
Direction: nftypes.Ingress,
Protocol: nftypes.UDP,
// TODO: handle ipv6
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
SourcePort: id.RemotePort,
DestPort: id.LocalPort,
}
if ep != nil {
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
// fields are flipped since this is the in conn
// TODO: get bytes
fields.RxPackets = tcpStats.PacketsSent.Value()
fields.TxPackets = tcpStats.PacketsReceived.Value()
}
}
f.flowLogger.StoreEvent(fields)
}
func (c *udpPacketConn) updateLastSeen() { func (c *udpPacketConn) updateLastSeen() {
c.lastSeen.Store(time.Now().UnixNano()) c.lastSeen.Store(time.Now().UnixNano())
} }

View File

@@ -3,6 +3,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
m.ipv4Bitmap[high] |= 1 << (low % 32) m.ipv4Bitmap[high] |= 1 << (low % 32)
} }
func (m *localIPManager) checkBitmapBit(ip net.IP) bool { func (m *localIPManager) checkBitmapBit(ip []byte) bool {
ipv4 := ip.To4() high := (uint16(ip[0]) << 8) | uint16(ip[1])
if ipv4 == nil { low := (uint16(ip[2]) << 8) | uint16(ip[3])
return false
}
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
} }
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
return nil return nil
} }
func (m *localIPManager) IsLocalIP(ip net.IP) bool { func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
m.mu.RLock() m.mu.RLock()
defer m.mu.RUnlock() defer m.mu.RUnlock()
if ipv4 := ip.To4(); ipv4 != nil { if ip.Is4() {
return m.checkBitmapBit(ipv4) return m.checkBitmapBit(ip.AsSlice())
} }
return false return false

View File

@@ -2,90 +2,91 @@ package uspfilter
import ( import (
"net" "net"
"net/netip"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestLocalIPManager(t *testing.T) { func TestLocalIPManager(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
setupAddr iface.WGAddress setupAddr wgaddr.Address
testIP net.IP testIP netip.Addr
expected bool expected bool
}{ }{
{ {
name: "Localhost range", name: "Localhost range",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.2"), testIP: netip.MustParseAddr("127.0.0.2"),
expected: true, expected: true,
}, },
{ {
name: "Localhost standard address", name: "Localhost standard address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.0.0.1"), testIP: netip.MustParseAddr("127.0.0.1"),
expected: true, expected: true,
}, },
{ {
name: "Localhost range edge", name: "Localhost range edge",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("127.255.255.255"), testIP: netip.MustParseAddr("127.255.255.255"),
expected: true, expected: true,
}, },
{ {
name: "Local IP matches", name: "Local IP matches",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.1"), testIP: netip.MustParseAddr("192.168.1.1"),
expected: true, expected: true,
}, },
{ {
name: "Local IP doesn't match", name: "Local IP doesn't match",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("192.168.1.1"), IP: net.ParseIP("192.168.1.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("192.168.1.0"), IP: net.ParseIP("192.168.1.0"),
Mask: net.CIDRMask(24, 32), Mask: net.CIDRMask(24, 32),
}, },
}, },
testIP: net.ParseIP("192.168.1.2"), testIP: netip.MustParseAddr("192.168.1.2"),
expected: false, expected: false,
}, },
{ {
name: "IPv6 address", name: "IPv6 address",
setupAddr: iface.WGAddress{ setupAddr: wgaddr.Address{
IP: net.ParseIP("fe80::1"), IP: net.ParseIP("fe80::1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("fe80::"), IP: net.ParseIP("fe80::"),
Mask: net.CIDRMask(64, 128), Mask: net.CIDRMask(64, 128),
}, },
}, },
testIP: net.ParseIP("fe80::1"), testIP: netip.MustParseAddr("fe80::1"),
expected: false, expected: false,
}, },
} }
@@ -95,7 +96,7 @@ func TestLocalIPManager(t *testing.T) {
manager := newLocalIPManager() manager := newLocalIPManager()
mock := &IFaceMock{ mock := &IFaceMock{
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return tt.setupAddr return tt.setupAddr
}, },
} }
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
t.Logf("Testing %d IPs", len(tests)) t.Logf("Testing %d IPs", len(tests))
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.ip, func(t *testing.T) { t.Run(tt.ip, func(t *testing.T) {
result := manager.IsLocalIP(net.ParseIP(tt.ip)) result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
require.Equal(t, tt.expected, result, "IP: %s", tt.ip) require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
}) })
} }

View File

@@ -1,4 +1,4 @@
// Package logger provides a high-performance, non-blocking logger for userspace networking // Package log provides a high-performance, non-blocking logger for userspace networking
package log package log
import ( import (
@@ -13,13 +13,12 @@ import (
) )
const ( const (
maxBatchSize = 1024 * 16 // 16KB max batch size maxBatchSize = 1024 * 16
maxMessageSize = 1024 * 2 // 2KB per message maxMessageSize = 1024 * 2
bufferSize = 1024 * 256 // 256KB ring buffer
defaultFlushInterval = 2 * time.Second defaultFlushInterval = 2 * time.Second
logChannelSize = 1000
) )
// Level represents log severity
type Level uint32 type Level uint32
const ( const (
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
LevelTrace: "TRAC", LevelTrace: "TRAC",
} }
// Logger is a high-performance, non-blocking logger type logMessage struct {
type Logger struct { level Level
output io.Writer format string
level atomic.Uint32 args []any
buffer *ringBuffer
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
// Reusable buffer pool for formatting messages
bufPool sync.Pool
} }
// Logger is a high-performance, non-blocking logger
type Logger struct {
output io.Writer
level atomic.Uint32
msgChannel chan logMessage
shutdown chan struct{}
closeOnce sync.Once
wg sync.WaitGroup
bufPool sync.Pool
}
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
func NewFromLogrus(logrusLogger *log.Logger) *Logger { func NewFromLogrus(logrusLogger *log.Logger) *Logger {
l := &Logger{ l := &Logger{
output: logrusLogger.Out, output: logrusLogger.Out,
buffer: newRingBuffer(bufferSize), msgChannel: make(chan logMessage, logChannelSize),
shutdown: make(chan struct{}), shutdown: make(chan struct{}),
bufPool: sync.Pool{ bufPool: sync.Pool{
New: func() interface{} { New: func() any {
// Pre-allocate buffer for message formatting
b := make([]byte, 0, maxMessageSize) b := make([]byte, 0, maxMessageSize)
return &b return &b
}, },
}, },
} }
logrusLevel := logrusLogger.GetLevel() logrusLevel := logrusLogger.GetLevel()
l.level.Store(uint32(logrusLevel)) l.level.Store(uint32(logrusLevel))
level := levelStrings[Level(logrusLevel)] level := levelStrings[Level(logrusLevel)]
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
return l return l
} }
// SetLevel sets the logging level
func (l *Logger) SetLevel(level Level) { func (l *Logger) SetLevel(level Level) {
l.level.Store(uint32(level)) l.level.Store(uint32(level))
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
} }
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { func (l *Logger) log(level Level, format string, args ...any) {
*buf = (*buf)[:0] select {
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
// Timestamp default:
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
// Level
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
// Message
if len(args) > 0 {
*buf = append(*buf, fmt.Sprintf(format, args...)...)
} else {
*buf = append(*buf, format...)
} }
*buf = append(*buf, '\n')
} }
func (l *Logger) log(level Level, format string, args ...interface{}) { // Error logs a message at error level
bufp := l.bufPool.Get().(*[]byte) func (l *Logger) Error(format string, args ...any) {
l.formatMessage(bufp, level, format, args...)
if len(*bufp) > maxMessageSize {
*bufp = (*bufp)[:maxMessageSize]
}
_, _ = l.buffer.Write(*bufp)
l.bufPool.Put(bufp)
}
func (l *Logger) Error(format string, args ...interface{}) {
if l.level.Load() >= uint32(LevelError) { if l.level.Load() >= uint32(LevelError) {
l.log(LevelError, format, args...) l.log(LevelError, format, args...)
} }
} }
func (l *Logger) Warn(format string, args ...interface{}) { // Warn logs a message at warning level
func (l *Logger) Warn(format string, args ...any) {
if l.level.Load() >= uint32(LevelWarn) { if l.level.Load() >= uint32(LevelWarn) {
l.log(LevelWarn, format, args...) l.log(LevelWarn, format, args...)
} }
} }
func (l *Logger) Info(format string, args ...interface{}) { // Info logs a message at info level
func (l *Logger) Info(format string, args ...any) {
if l.level.Load() >= uint32(LevelInfo) { if l.level.Load() >= uint32(LevelInfo) {
l.log(LevelInfo, format, args...) l.log(LevelInfo, format, args...)
} }
} }
func (l *Logger) Debug(format string, args ...interface{}) { // Debug logs a message at debug level
func (l *Logger) Debug(format string, args ...any) {
if l.level.Load() >= uint32(LevelDebug) { if l.level.Load() >= uint32(LevelDebug) {
l.log(LevelDebug, format, args...) l.log(LevelDebug, format, args...)
} }
} }
func (l *Logger) Trace(format string, args ...interface{}) { // Trace logs a message at trace level
func (l *Logger) Trace(format string, args ...any) {
if l.level.Load() >= uint32(LevelTrace) { if l.level.Load() >= uint32(LevelTrace) {
l.log(LevelTrace, format, args...) l.log(LevelTrace, format, args...)
} }
} }
// worker periodically flushes the buffer func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
*buf = (*buf)[:0]
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
*buf = append(*buf, ' ')
*buf = append(*buf, levelStrings[level]...)
*buf = append(*buf, ' ')
var msg string
if len(args) > 0 {
msg = fmt.Sprintf(format, args...)
} else {
msg = format
}
*buf = append(*buf, msg...)
*buf = append(*buf, '\n')
if len(*buf) > maxMessageSize {
*buf = (*buf)[:maxMessageSize]
}
}
// processMessage handles a single log message and adds it to the buffer
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
bufp := l.bufPool.Get().(*[]byte)
defer l.bufPool.Put(bufp)
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
if len(*buffer)+len(*bufp) > maxBatchSize {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
*buffer = append(*buffer, *bufp...)
}
// flushBuffer writes the accumulated buffer to output
func (l *Logger) flushBuffer(buffer *[]byte) {
if len(*buffer) > 0 {
_, _ = l.output.Write(*buffer)
*buffer = (*buffer)[:0]
}
}
// processBatch processes as many messages as possible without blocking
func (l *Logger) processBatch(buffer *[]byte) {
for len(*buffer) < maxBatchSize {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
default:
return
}
}
}
// handleShutdown manages the graceful shutdown sequence with timeout
func (l *Logger) handleShutdown(buffer *[]byte) {
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer cancel()
for {
select {
case msg := <-l.msgChannel:
l.processMessage(msg, buffer)
case <-ctx.Done():
l.flushBuffer(buffer)
return
}
if len(l.msgChannel) == 0 {
l.flushBuffer(buffer)
return
}
}
}
// worker is the main goroutine that processes log messages
func (l *Logger) worker() { func (l *Logger) worker() {
defer l.wg.Done() defer l.wg.Done()
ticker := time.NewTicker(defaultFlushInterval) ticker := time.NewTicker(defaultFlushInterval)
defer ticker.Stop() defer ticker.Stop()
buf := make([]byte, 0, maxBatchSize) buffer := make([]byte, 0, maxBatchSize)
for { for {
select { select {
case <-l.shutdown: case <-l.shutdown:
l.handleShutdown(&buffer)
return return
case <-ticker.C: case <-ticker.C:
// Read accumulated messages l.flushBuffer(&buffer)
n, _ := l.buffer.Read(buf[:cap(buf)]) case msg := <-l.msgChannel:
if n == 0 { l.processMessage(msg, &buffer)
continue l.processBatch(&buffer)
}
// Write batch
_, _ = l.output.Write(buf[:n])
} }
} }
} }

View File

@@ -0,0 +1,121 @@
package log_test
import (
"context"
"testing"
"time"
"github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
type discard struct{}
func (d *discard) Write(p []byte) (n int, err error) {
return len(p), nil
}
func BenchmarkLogger(b *testing.B) {
simpleMessage := "Connection established"
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4 // TCPStateEstablished
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
protocol := "TCP"
direction := "outbound"
flags := uint16(0x18) // ACK + PSH
sequence := uint32(123456789)
acknowledged := uint32(987654321)
payloadSize := 1460
fragmented := false
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
b.Run("SimpleMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(simpleMessage)
}
})
b.Run("ConntrackMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
b.Run("ComplexMessage", func(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
b.ResetTimer()
for i := 0; i < b.N; i++ {
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
}
})
}
// BenchmarkLoggerParallel tests the logger under concurrent load
func BenchmarkLoggerParallel(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
})
}
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
func BenchmarkLoggerBurst(b *testing.B) {
logger := createTestLogger()
defer cleanupLogger(logger)
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
srcIP := "192.168.1.1"
srcPort := uint16(12345)
dstIP := "10.0.0.1"
dstPort := uint16(443)
state := 4
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < 100; j++ {
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
}
}
}
func createTestLogger() *log.Logger {
logrusLogger := logrus.New()
logrusLogger.SetOutput(&discard{})
logrusLogger.SetLevel(logrus.TraceLevel)
return log.NewFromLogrus(logrusLogger)
}
func cleanupLogger(logger *log.Logger) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_ = logger.Stop(ctx)
}

View File

@@ -1,85 +0,0 @@
package log
import "sync"
// ringBuffer is a simple ring buffer implementation
type ringBuffer struct {
buf []byte
size int
r, w int64 // Read and write positions
mu sync.Mutex
}
func newRingBuffer(size int) *ringBuffer {
return &ringBuffer{
buf: make([]byte, size),
size: size,
}
}
func (r *ringBuffer) Write(p []byte) (n int, err error) {
if len(p) == 0 {
return 0, nil
}
r.mu.Lock()
defer r.mu.Unlock()
if len(p) > r.size {
p = p[:r.size]
}
n = len(p)
// Write data, handling wrap-around
pos := int(r.w % int64(r.size))
writeLen := min(len(p), r.size-pos)
copy(r.buf[pos:], p[:writeLen])
// If we have more data and need to wrap around
if writeLen < len(p) {
copy(r.buf, p[writeLen:])
}
// Update write position
r.w += int64(n)
return n, nil
}
func (r *ringBuffer) Read(p []byte) (n int, err error) {
r.mu.Lock()
defer r.mu.Unlock()
if r.w == r.r {
return 0, nil
}
// Calculate available data accounting for wraparound
available := int(r.w - r.r)
if available < 0 {
available += r.size
}
available = min(available, r.size)
// Limit read to buffer size
toRead := min(available, len(p))
if toRead == 0 {
return 0, nil
}
// Read data, handling wrap-around
pos := int(r.r % int64(r.size))
readLen := min(toRead, r.size-pos)
n = copy(p, r.buf[pos:pos+readLen])
// If we need more data and need to wrap around
if readLen < toRead {
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
}
// Update read position
r.r += int64(n)
return n, nil
}

View File

@@ -1,7 +1,6 @@
package uspfilter package uspfilter
import ( import (
"net"
"net/netip" "net/netip"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -12,25 +11,26 @@ import (
// PeerRule to handle management of rules // PeerRule to handle management of rules
type PeerRule struct { type PeerRule struct {
id string id string
ip net.IP mgmtId []byte
ip netip.Addr
ipLayer gopacket.LayerType ipLayer gopacket.LayerType
matchByIP bool matchByIP bool
protoLayer gopacket.LayerType protoLayer gopacket.LayerType
sPort *firewall.Port sPort *firewall.Port
dPort *firewall.Port dPort *firewall.Port
drop bool drop bool
comment string
udpHook func([]byte) bool udpHook func([]byte) bool
} }
// GetRuleID returns the rule id // ID returns the rule id
func (r *PeerRule) GetRuleID() string { func (r *PeerRule) ID() string {
return r.id return r.id
} }
type RouteRule struct { type RouteRule struct {
id string id string
mgmtId []byte
sources []netip.Prefix sources []netip.Prefix
destination netip.Prefix destination netip.Prefix
proto firewall.Protocol proto firewall.Protocol
@@ -39,7 +39,7 @@ type RouteRule struct {
action firewall.Action action firewall.Action
} }
// GetRuleID returns the rule id // ID returns the rule id
func (r *RouteRule) GetRuleID() string { func (r *RouteRule) ID() string {
return r.id return r.id
} }

View File

@@ -2,7 +2,7 @@ package uspfilter
import ( import (
"fmt" "fmt"
"net" "net/netip"
"time" "time"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -53,8 +53,8 @@ type TraceResult struct {
} }
type PacketTrace struct { type PacketTrace struct {
SourceIP net.IP SourceIP netip.Addr
DestinationIP net.IP DestinationIP netip.Addr
Protocol string Protocol string
SourcePort uint16 SourcePort uint16
DestinationPort uint16 DestinationPort uint16
@@ -72,8 +72,8 @@ type TCPState struct {
} }
type PacketBuilder struct { type PacketBuilder struct {
SrcIP net.IP SrcIP netip.Addr
DstIP net.IP DstIP netip.Addr
Protocol fw.Protocol Protocol fw.Protocol
SrcPort uint16 SrcPort uint16
DstPort uint16 DstPort uint16
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
Version: 4, Version: 4,
TTL: 64, TTL: 64,
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
SrcIP: p.SrcIP, SrcIP: p.SrcIP.AsSlice(),
DstIP: p.DstIP, DstIP: p.DstIP.AsSlice(),
} }
} }
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
return m.traceInbound(packetData, trace, d, srcIP, dstIP) return m.traceInbound(packetData, trace, d, srcIP, dstIP)
} }
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
return trace return trace
} }
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return trace if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
return trace
}
} }
if !m.handleRouting(trace) { if !m.handleRouting(trace) {
return trace return trace
} }
if m.nativeRouter { if m.nativeRouter.Load() {
return m.handleNativeRouter(trace) return m.handleNativeRouter(trace)
} }
return m.handleRouteACLs(trace, d, srcIP, dstIP) return m.handleRouteACLs(trace, d, srcIP, dstIP)
} }
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
allowed := m.isValidTrackedConnection(d, srcIP, dstIP) allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
msg := "No existing connection found" msg := "No existing connection found"
if allowed { if allowed {
msg = m.buildConntrackStateMessage(d) msg = m.buildConntrackStateMessage(d)
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
return msg return msg
} }
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
if !m.localForwarding { trace.AddResult(StageRouting, "Packet destined for local delivery", true)
trace.AddResult(StageRouting, "Local forwarding disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
strRuleId := "<no id>"
if ruleId != nil {
strRuleId = string(ruleId)
}
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
if blocked {
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
trace.AddResult(StagePeerACL, msg, false)
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
return true return true
} }
trace.AddResult(StageRouting, "Packet destined for local delivery", true) trace.AddResult(StagePeerACL, msg, true)
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
msg := "Allowed by peer ACL rules"
if blocked {
msg = "Blocked by peer ACL rules"
}
trace.AddResult(StagePeerACL, msg, !blocked)
// Handle netstack mode
if m.netstack { if m.netstack {
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) switch {
case !m.localForwarding:
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
case m.forwarder.Load() != nil:
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
default:
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
}
return true
} }
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) // In normal mode, packets are allowed through for local delivery
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
return true return true
} }
func (m *Manager) handleRouting(trace *PacketTrace) bool { func (m *Manager) handleRouting(trace *PacketTrace) bool {
if !m.routingEnabled { if !m.routingEnabled.Load() {
trace.AddResult(StageRouting, "Routing disabled", false) trace.AddResult(StageRouting, "Routing disabled", false)
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
return false return false
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
return trace return trace
} }
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
proto := getProtocolFromPacket(d) proto, _ := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
msg := "Allowed by route ACLs" strId := string(id)
if id == nil {
strId = "<no id>"
}
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
if !allowed { if !allowed {
msg = "Blocked by route ACLs" msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
} }
trace.AddResult(StageRouteACL, msg, allowed) trace.AddResult(StageRouteACL, msg, allowed)
if allowed && m.forwarder != nil { if allowed && m.forwarder.Load() != nil {
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
} }
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
// will create or update the connection state // will create or update the connection state
dropped := m.processOutgoingHooks(packetData) dropped := m.processOutgoingHooks(packetData, 0)
if dropped { if dropped {
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
} else { } else {

View File

@@ -0,0 +1,440 @@
package uspfilter
import (
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
t.Logf("Trace results: %v", trace.Results)
actualStages := make([]PacketStage, 0, len(trace.Results))
for _, result := range trace.Results {
actualStages = append(actualStages, result.Stage)
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
}
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
}
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
require.NotEmpty(t, trace.Results, "Trace should have results")
lastResult := trace.Results[len(trace.Results)-1]
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
}
func TestTracePacket(t *testing.T) {
setupTracerTest := func(statefulMode bool) *Manager {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"),
Mask: net.CIDRMask(16, 32),
},
}
},
}
m, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err)
if !statefulMode {
m.stateful = false
}
return m
}
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
builder := &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: protocol,
SrcPort: srcPort,
DstPort: dstPort,
Direction: direction,
}
if protocol == "tcp" {
builder.TCPState = &TCPState{SYN: true}
}
return builder
}
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
return &PacketBuilder{
SrcIP: netip.MustParseAddr(srcIP),
DstIP: netip.MustParseAddr(dstIP),
Protocol: "icmp",
ICMPType: icmpType,
ICMPCode: icmpCode,
Direction: direction,
}
}
testCases := []struct {
name string
setup func(*Manager)
packetBuilder func() *PacketBuilder
expectedStages []PacketStage
expectedAllow bool
}{
{
name: "LocalTraffic_ACLAllowed",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_ACLDenied",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "LocalTraffic_WithForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = true
m.forwarder.Store(&forwarder.Forwarder{})
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "LocalTraffic_WithoutForwarder",
setup: func(m *Manager) {
m.netstack = true
m.localForwarding = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLAllowed",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
m.forwarder.Store(&forwarder.Forwarder{})
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_ACLDenied",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(false)
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "RoutedTraffic_NativeRouter",
setup: func(m *Manager) {
m.routingEnabled.Store(true)
m.nativeRouter.Store(true)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageRouteACL,
StageForwarding,
StageCompleted,
},
expectedAllow: true,
},
{
name: "RoutedTraffic_RoutingDisabled",
setup: func(m *Manager) {
m.routingEnabled.Store(false)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StageCompleted,
},
expectedAllow: false,
},
{
name: "ConnectionTracking_Hit",
setup: func(m *Manager) {
srcIP := netip.MustParseAddr("100.10.0.100")
dstIP := netip.MustParseAddr("1.1.1.1")
srcPort := uint16(12345)
dstPort := uint16(80)
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
},
packetBuilder: func() *PacketBuilder {
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
pb.TCPState = &TCPState{SYN: true, ACK: true}
return pb
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageCompleted,
},
expectedAllow: true,
},
{
name: "OutboundTraffic",
setup: func(m *Manager) {
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
},
expectedStages: []PacketStage{
StageReceived,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPEchoRequest",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "ICMPDestinationUnreachable",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolICMP
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithoutHook",
setup: func(m *Manager) {
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolUDP
port := &fw.Port{Values: []uint16{53}}
action := fw.ActionAccept
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: true,
},
{
name: "UDPTraffic_WithHook",
setup: func(m *Manager) {
hookFunc := func([]byte) bool {
return true
}
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageConntrack,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
{
name: "StatefulDisabled_NoTracking",
setup: func(m *Manager) {
m.stateful = false
ip := net.ParseIP("1.1.1.1")
proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
require.NoError(t, err)
},
packetBuilder: func() *PacketBuilder {
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
},
expectedStages: []PacketStage{
StageReceived,
StageRouting,
StagePeerACL,
StageCompleted,
},
expectedAllow: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
m := setupTracerTest(true)
tc.setup(m)
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
"100.10.0.100 should be recognized as a local IP")
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
"172.17.0.2 should not be recognized as a local IP")
pb := tc.packetBuilder()
trace, err := m.TracePacketFromBuilder(pb)
require.NoError(t, err)
verifyTraceStages(t, trace, tc.expectedStages)
verifyFinalDisposition(t, trace, tc.expectedAllow)
})
}
}

View File

@@ -10,6 +10,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"github.com/google/gopacket" "github.com/google/gopacket"
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
@@ -22,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
@@ -42,6 +44,8 @@ const (
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
) )
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
// RuleSet is a set of rules grouped by a string key // RuleSet is a set of rules grouped by a string key
type RuleSet map[string]PeerRule type RuleSet map[string]PeerRule
@@ -63,9 +67,9 @@ func (r RouteRules) Sort() {
// Manager userspace firewall manager // Manager userspace firewall manager
type Manager struct { type Manager struct {
// outgoingRules is used for hooks only // outgoingRules is used for hooks only
outgoingRules map[string]RuleSet outgoingRules map[netip.Addr]RuleSet
// incomingRules is used for filtering and hooks // incomingRules is used for filtering and hooks
incomingRules map[string]RuleSet incomingRules map[netip.Addr]RuleSet
routeRules RouteRules routeRules RouteRules
wgNetwork *net.IPNet wgNetwork *net.IPNet
decoders sync.Pool decoders sync.Pool
@@ -77,9 +81,9 @@ type Manager struct {
// indicates whether server routes are disabled // indicates whether server routes are disabled
disableServerRoutes bool disableServerRoutes bool
// indicates whether we forward packets not destined for ourselves // indicates whether we forward packets not destined for ourselves
routingEnabled bool routingEnabled atomic.Bool
// indicates whether we leave forwarding and filtering to the native firewall // indicates whether we leave forwarding and filtering to the native firewall
nativeRouter bool nativeRouter atomic.Bool
// indicates whether we track outbound connections // indicates whether we track outbound connections
stateful bool stateful bool
// indicates whether wireguards runs in netstack mode // indicates whether wireguards runs in netstack mode
@@ -92,8 +96,9 @@ type Manager struct {
udpTracker *conntrack.UDPTracker udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker tcpTracker *conntrack.TCPTracker
forwarder *forwarder.Forwarder forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger logger *nblog.Logger
flowLogger nftypes.FlowLogger
} }
// decoder for packages // decoder for packages
@@ -110,16 +115,16 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes) return create(iface, nil, disableServerRoutes, flowLogger)
} }
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
if nativeFirewall == nil { if nativeFirewall == nil {
return nil, errors.New("native firewall is nil") return nil, errors.New("native firewall is nil")
} }
mgr, err := create(iface, nativeFirewall, disableServerRoutes) mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -146,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding
} }
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv() disableConntrack, enableLocalForwarding := parseCreateEnv()
m := &Manager{ m := &Manager{
@@ -164,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
}, },
}, },
nativeFirewall: nativeFirewall, nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet), outgoingRules: make(map[netip.Addr]RuleSet),
incomingRules: make(map[string]RuleSet), incomingRules: make(map[netip.Addr]RuleSet),
wgIface: iface, wgIface: iface,
localipmanager: newLocalIPManager(), localipmanager: newLocalIPManager(),
disableServerRoutes: disableServerRoutes, disableServerRoutes: disableServerRoutes,
routingEnabled: false,
stateful: !disableConntrack, stateful: !disableConntrack,
logger: nblog.NewFromLogrus(log.StandardLogger()), logger: nblog.NewFromLogrus(log.StandardLogger()),
flowLogger: flowLogger,
netstack: netstack.IsEnabled(), netstack: netstack.IsEnabled(),
localForwarding: enableLocalForwarding, localForwarding: enableLocalForwarding,
} }
m.routingEnabled.Store(false)
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
@@ -183,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
} }
// netstack needs the forwarder for local traffic // netstack needs the forwarder for local traffic
@@ -206,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
} }
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
if m.forwarder == nil { if m.forwarder.Load() == nil {
return nil return nil
} }
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
@@ -216,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
log.Debugf("blocking invalid routed traffic for %s", wgPrefix) log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
if _, err := m.AddRouteFiltering( if _, err := m.AddRouteFiltering(
nil,
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
wgPrefix, wgPrefix,
firewall.ProtocolALL, firewall.ProtocolALL,
@@ -249,20 +256,20 @@ func (m *Manager) determineRouting() error {
switch { switch {
case disableUspRouting: case disableUspRouting:
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is disabled") log.Info("userspace routing is disabled")
case m.disableServerRoutes: case m.disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack // if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("server routes are disabled") log.Info("server routes are disabled")
case forceUserspaceRouter: case forceUserspaceRouter:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing is forced") log.Info("userspace routing is forced")
@@ -270,19 +277,19 @@ func (m *Manager) determineRouting() error {
// if the OS supports routing natively, then we don't need to filter/route ourselves // if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface // netstack mode won't support native routing as there is no interface
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = true m.nativeRouter.Store(true)
log.Info("native routing is enabled") log.Info("native routing is enabled")
default: default:
m.routingEnabled = true m.routingEnabled.Store(true)
m.nativeRouter = false m.nativeRouter.Store(false)
log.Info("userspace routing enabled by default") log.Info("userspace routing enabled by default")
} }
if m.routingEnabled && !m.nativeRouter { if m.routingEnabled.Load() && !m.nativeRouter.Load() {
return m.initForwarder() return m.initForwarder()
} }
@@ -291,24 +298,24 @@ func (m *Manager) determineRouting() error {
// initForwarder initializes the forwarder, it disables routing on errors // initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error { func (m *Manager) initForwarder() error {
if m.forwarder != nil { if m.forwarder.Load() != nil {
return nil return nil
} }
// Only supported in userspace mode as we need to inject packets back into wireguard directly // Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := m.wgIface.GetWGDevice() intf := m.wgIface.GetWGDevice()
if intf == nil { if intf == nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
if err != nil { if err != nil {
m.routingEnabled = false m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
} }
m.forwarder = forwarder m.forwarder.Store(forwarder)
log.Debug("forwarder initialized") log.Debug("forwarder initialized")
@@ -324,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
} }
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair) return m.nativeFirewall.AddNatRule(pair)
} }
@@ -335,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
// RemoveNatRule removes a routing firewall rule // RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair) return m.nativeFirewall.RemoveNatRule(pair)
} }
return nil return nil
@@ -346,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
// If comment argument is empty firewall manager should set // If comment argument is empty firewall manager should set
// rule ID as comment for the rule // rule ID as comment for the rule
func (m *Manager) AddPeerFiltering( func (m *Manager) AddPeerFiltering(
id []byte,
ip net.IP, ip net.IP,
proto firewall.Protocol, proto firewall.Protocol,
sPort *firewall.Port, sPort *firewall.Port,
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
_ string, _ string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
// TODO: fix in upper layers
i, ok := netip.AddrFromSlice(ip)
if !ok {
return nil, fmt.Errorf("invalid IP: %s", ip)
}
i = i.Unmap()
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, mgmtId: id,
ip: i,
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
matchByIP: true, matchByIP: true,
drop: action == firewall.ActionDrop, drop: action == firewall.ActionDrop,
comment: comment,
} }
if ipNormalized := ip.To4(); ipNormalized != nil { if i.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
r.ip = ipNormalized
} }
if s := r.ip.String(); s == "0.0.0.0" || s == "::" { if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
@@ -389,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
} }
m.mutex.Lock() m.mutex.Lock()
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(RuleSet) m.incomingRules[r.ip] = make(RuleSet)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
m.mutex.Unlock() m.mutex.Unlock()
return []firewall.Rule{&r}, nil return []firewall.Rule{&r}, nil
} }
func (m *Manager) AddRouteFiltering( func (m *Manager) AddRouteFiltering(
id []byte,
sources []netip.Prefix, sources []netip.Prefix,
destination netip.Prefix, destination netip.Prefix,
proto firewall.Protocol, proto firewall.Protocol,
@@ -405,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port, dPort *firewall.Port,
action firewall.Action, action firewall.Action,
) (firewall.Rule, error) { ) (firewall.Rule, error) {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
} }
m.mutex.Lock()
defer m.mutex.Unlock()
ruleID := uuid.New().String() ruleID := uuid.New().String()
rule := RouteRule{ rule := RouteRule{
// TODO: consolidate these IDs
id: ruleID, id: ruleID,
mgmtId: id,
sources: sources, sources: sources,
destination: destination, destination: destination,
proto: proto, proto: proto,
@@ -423,21 +436,23 @@ func (m *Manager) AddRouteFiltering(
action: action, action: action,
} }
m.mutex.Lock()
m.routeRules = append(m.routeRules, rule) m.routeRules = append(m.routeRules, rule)
m.routeRules.Sort() m.routeRules.Sort()
m.mutex.Unlock()
return &rule, nil return &rule, nil
} }
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter && m.nativeFirewall != nil { if m.nativeRouter.Load() && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule) return m.nativeFirewall.DeleteRouteRule(rule)
} }
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
ruleID := rule.GetRuleID() ruleID := rule.ID()
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
return r.id == ruleID return r.id == ruleID
}) })
@@ -459,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
return fmt.Errorf("delete rule: invalid rule type: %T", rule) return fmt.Errorf("delete rule: invalid rule type: %T", rule)
} }
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok { if _, ok := m.incomingRules[r.ip][r.id]; !ok {
return fmt.Errorf("delete rule: no rule with such id: %v", r.id) return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
} }
delete(m.incomingRules[r.ip.String()], r.id) delete(m.incomingRules[r.ip], r.id)
return nil return nil
} }
@@ -478,14 +493,30 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
// Flush doesn't need to be implemented for this manager // Flush doesn't need to be implemented for this manager
func (m *Manager) Flush() error { return nil } func (m *Manager) Flush() error { return nil }
// AddDNATRule adds a DNAT rule
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
if m.nativeFirewall == nil {
return nil, errNatNotSupported
}
return m.nativeFirewall.AddDNATRule(rule)
}
// DeleteDNATRule deletes a DNAT rule
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
if m.nativeFirewall == nil {
return errNatNotSupported
}
return m.nativeFirewall.DeleteDNATRule(rule)
}
// DropOutgoing filter outgoing packets // DropOutgoing filter outgoing packets
func (m *Manager) DropOutgoing(packetData []byte) bool { func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
return m.processOutgoingHooks(packetData) return m.processOutgoingHooks(packetData, size)
} }
// DropIncoming filter incoming packets // DropIncoming filter incoming packets
func (m *Manager) DropIncoming(packetData []byte) bool { func (m *Manager) DropIncoming(packetData []byte, size int) bool {
return m.dropFilter(packetData) return m.dropFilter(packetData, size)
} }
// UpdateLocalIPs updates the list of local IPs // UpdateLocalIPs updates the list of local IPs
@@ -493,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
return m.localipmanager.UpdateLocalIPs(m.wgIface) return m.localipmanager.UpdateLocalIPs(m.wgIface)
} }
func (m *Manager) processOutgoingHooks(packetData []byte) bool { func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -509,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0])
return false return false
} }
// Track all protocols if stateful mode is enabled if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
if m.stateful { return true
switch d.decoded[1] {
case layers.LayerTypeUDP:
m.trackUDPOutbound(d, srcIP, dstIP)
case layers.LayerTypeTCP:
m.trackTCPOutbound(d, srcIP, dstIP)
case layers.LayerTypeICMPv4:
m.trackICMPOutbound(d, srcIP, dstIP)
}
} }
// Process UDP hooks even if stateful mode is disabled if m.stateful {
if d.decoded[1] == layers.LayerTypeUDP { m.trackOutbound(d, srcIP, dstIP, size)
return m.checkUDPHooks(d, dstIP, packetData)
} }
return false return false
} }
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) { func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
switch d.decoded[0] { switch d.decoded[0] {
case layers.LayerTypeIPv4: case layers.LayerTypeIPv4:
return d.ip4.SrcIP, d.ip4.DstIP src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
return src, dst
case layers.LayerTypeIPv6: case layers.LayerTypeIPv6:
return d.ip6.SrcIP, d.ip6.DstIP src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
return src, dst
default: default:
return nil, nil return netip.Addr{}, netip.Addr{}
} }
} }
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(
srcIP,
dstIP,
uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort),
flags,
)
}
func getTCPFlags(tcp *layers.TCP) uint8 { func getTCPFlags(tcp *layers.TCP) uint8 {
var flags uint8 var flags uint8
if tcp.SYN { if tcp.SYN {
@@ -578,45 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
m.udpTracker.TrackOutbound( transport := d.decoded[1]
srcIP, switch transport {
dstIP, case layers.LayerTypeUDP:
uint16(d.udp.SrcPort), m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
uint16(d.udp.DstPort), case layers.LayerTypeTCP:
) flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
}
} }
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool { func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} { transport := d.decoded[1]
if rules, exists := m.outgoingRules[ipKey]; exists { switch transport {
for _, rule := range rules { case layers.LayerTypeUDP:
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
return rule.udpHook(packetData) case layers.LayerTypeTCP:
} flags := getTCPFlags(&d.tcp)
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
case layers.LayerTypeICMPv4:
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
}
}
// udpHooksDrop checks if any UDP hooks should drop the packet
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
// Check specific destination IP first
if rules, exists := m.outgoingRules[dstIP]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
} }
} }
} }
return false
}
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { // Check IPv4 unspecified address
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest { if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
m.icmpTracker.TrackOutbound( for _, rule := range rules {
srcIP, if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
dstIP, return rule.udpHook(packetData)
d.icmp4.Id, }
d.icmp4.Seq, }
)
} }
// Check IPv6 unspecified address
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
for _, rule := range rules {
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
return rule.udpHook(packetData)
}
}
}
return false
} }
// dropFilter implements filtering logic for incoming packets. // dropFilter implements filtering logic for incoming packets.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) dropFilter(packetData []byte) bool { func (m *Manager) dropFilter(packetData []byte, size int) bool {
m.mutex.RLock()
defer m.mutex.RUnlock()
d := m.decoders.Get().(*decoder) d := m.decoders.Get().(*decoder)
defer m.decoders.Put(d) defer m.decoders.Put(d)
@@ -625,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
} }
srcIP, dstIP := m.extractIPs(d) srcIP, dstIP := m.extractIPs(d)
if srcIP == nil { if !srcIP.IsValid() {
m.logger.Error("Unknown network layer: %v", d.decoded[0]) m.logger.Error("Unknown network layer: %v", d.decoded[0])
return true return true
} }
// For all inbound traffic, first check if it matches a tracked connection. // For all inbound traffic, first check if it matches a tracked connection.
// This must happen before any other filtering because the packets are statefully tracked. // This must happen before any other filtering because the packets are statefully tracked.
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
return false return false
} }
if m.localipmanager.IsLocalIP(dstIP) { if m.localipmanager.IsLocalIP(dstIP) {
return m.handleLocalTraffic(d, srcIP, dstIP, packetData) return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
} }
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
@@ -645,10 +683,29 @@ func (m *Manager) dropFilter(packetData []byte) bool {
// handleLocalTraffic handles local traffic. // handleLocalTraffic handles local traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) { ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s", if blocked {
srcIP, dstIP) _, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d)
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
RxPackets: 1,
RxBytes: uint64(size),
})
return true return true
} }
@@ -657,6 +714,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
return m.handleNetstackLocalTraffic(packetData) return m.handleNetstackLocalTraffic(packetData)
} }
// track inbound packets to get the correct direction and session id for flows
m.trackInbound(d, srcIP, dstIP, ruleID, size)
return false return false
} }
@@ -666,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
return false return false
} }
if m.forwarder == nil { if m.forwarder.Load() == nil {
m.logger.Trace("Dropping local packet (forwarder not initialized)") m.logger.Trace("Dropping local packet (forwarder not initialized)")
return true return true
} }
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject local packet: %v", err) m.logger.Error("Failed to inject local packet: %v", err)
} }
@@ -681,30 +741,43 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
// handleRoutedTraffic handles routed traffic. // handleRoutedTraffic handles routed traffic.
// If it returns true, the packet should be dropped. // If it returns true, the packet should be dropped.
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
// Drop if routing is disabled // Drop if routing is disabled
if !m.routingEnabled { if !m.routingEnabled.Load() {
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
srcIP, dstIP) srcIP, dstIP)
return true return true
} }
// Pass to native stack if native router is enabled or forced // Pass to native stack if native router is enabled or forced
if m.nativeRouter { if m.nativeRouter.Load() {
return false return false
} }
proto := getProtocolFromPacket(d) proto, pnum := getProtocolFromPacket(d)
srcPort, dstPort := getPortsFromPacket(d) srcPort, dstPort := getPortsFromPacket(d)
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) { if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
srcIP, srcPort, dstIP, dstPort, proto) ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
m.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: uuid.New(),
Type: nftypes.TypeDrop,
RuleID: ruleID,
Direction: nftypes.Ingress,
Protocol: pnum,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
// TODO: icmp type/code
})
return true return true
} }
// Let forwarder handle the packet if it passed route ACLs // Let forwarder handle the packet if it passed route ACLs
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
m.logger.Error("Failed to inject incoming packet: %v", err) m.logger.Error("Failed to inject incoming packet: %v", err)
} }
@@ -712,16 +785,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
return true return true
} }
func getProtocolFromPacket(d *decoder) firewall.Protocol { func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return firewall.ProtocolTCP return firewall.ProtocolTCP, nftypes.TCP
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
return firewall.ProtocolUDP return firewall.ProtocolUDP, nftypes.UDP
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return firewall.ProtocolICMP return firewall.ProtocolICMP, nftypes.ICMP
default: default:
return firewall.ProtocolALL return firewall.ProtocolALL, nftypes.ProtocolUnknown
} }
} }
@@ -749,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
return true return true
} }
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
switch d.decoded[1] { switch d.decoded[1] {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
return m.tcpTracker.IsValidInbound( return m.tcpTracker.IsValidInbound(
@@ -758,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
uint16(d.tcp.SrcPort), uint16(d.tcp.SrcPort),
uint16(d.tcp.DstPort), uint16(d.tcp.DstPort),
getTCPFlags(&d.tcp), getTCPFlags(&d.tcp),
size,
) )
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
@@ -766,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
dstIP, dstIP,
uint16(d.udp.SrcPort), uint16(d.udp.SrcPort),
uint16(d.udp.DstPort), uint16(d.udp.DstPort),
size,
) )
case layers.LayerTypeICMPv4: case layers.LayerTypeICMPv4:
@@ -773,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
srcIP, srcIP,
dstIP, dstIP,
d.icmp4.Id, d.icmp4.Id,
d.icmp4.Seq,
d.icmp4.TypeCode.Type(), d.icmp4.TypeCode.Type(),
size,
) )
// TODO: ICMPv6 // TODO: ICMPv6
@@ -794,25 +869,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
icmpType == layers.ICMPv4TypeTimeExceeded icmpType == layers.ICMPv4TypeTimeExceeded
} }
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
m.mutex.RLock()
defer m.mutex.RUnlock()
if m.isSpecialICMP(d) { if m.isSpecialICMP(d) {
return false return nil, false
} }
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
return filter return mgmtId, filter
} }
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
return filter return mgmtId, filter
} }
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok { if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
return filter return mgmtId, filter
} }
// Default policy: DROP ALL // Default policy: DROP ALL
return true return nil, true
} }
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
@@ -832,15 +909,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
return false return false
} }
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) { func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
payloadLayer := d.decoded[1] payloadLayer := d.decoded[1]
for _, rule := range rules { for _, rule := range rules {
if rule.matchByIP && !ip.Equal(rule.ip) { if rule.matchByIP && ip.Compare(rule.ip) != 0 {
continue continue
} }
if rule.protoLayer == layerTypeAll { if rule.protoLayer == layerTypeAll {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
if payloadLayer != rule.protoLayer { if payloadLayer != rule.protoLayer {
@@ -850,39 +927,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
switch payloadLayer { switch payloadLayer {
case layers.LayerTypeTCP: case layers.LayerTypeTCP:
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) { if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
case layers.LayerTypeUDP: case layers.LayerTypeUDP:
// if rule has UDP hook (and if we are here we match this rule) // if rule has UDP hook (and if we are here we match this rule)
// we ignore rule.drop and call this hook // we ignore rule.drop and call this hook
if rule.udpHook != nil { if rule.udpHook != nil {
return rule.udpHook(packetData), true return rule.mgmtId, rule.udpHook(packetData), true
} }
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) { if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
return rule.drop, true return rule.mgmtId, rule.drop, true
} }
} }
return false, false return nil, false, false
} }
// routeACLsPass returns treu if the packet is allowed by the route ACLs // routeACLsPass returns true if the packet is allowed by the route ACLs
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
m.mutex.RLock() m.mutex.RLock()
defer m.mutex.RUnlock() defer m.mutex.RUnlock()
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
for _, rule := range m.routeRules { for _, rule := range m.routeRules {
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
return rule.action == firewall.ActionAccept return rule.mgmtId, rule.action == firewall.ActionAccept
} }
} }
return false return nil, false
} }
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
@@ -922,36 +996,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not // Hook function returns flag which indicates should be the matched package dropped or not
func (m *Manager) AddUDPPacketHook( func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
) string {
r := PeerRule{ r := PeerRule{
id: uuid.New().String(), id: uuid.New().String(),
ip: ip, ip: ip,
protoLayer: layers.LayerTypeUDP, protoLayer: layers.LayerTypeUDP,
dPort: &firewall.Port{Values: []uint16{dPort}}, dPort: &firewall.Port{Values: []uint16{dPort}},
ipLayer: layers.LayerTypeIPv6, ipLayer: layers.LayerTypeIPv6,
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
udpHook: hook, udpHook: hook,
} }
if ip.To4() != nil { if ip.Is4() {
r.ipLayer = layers.LayerTypeIPv4 r.ipLayer = layers.LayerTypeIPv4
} }
m.mutex.Lock() m.mutex.Lock()
if in { if in {
if _, ok := m.incomingRules[r.ip.String()]; !ok { if _, ok := m.incomingRules[r.ip]; !ok {
m.incomingRules[r.ip.String()] = make(map[string]PeerRule) m.incomingRules[r.ip] = make(map[string]PeerRule)
} }
m.incomingRules[r.ip.String()][r.id] = r m.incomingRules[r.ip][r.id] = r
} else { } else {
if _, ok := m.outgoingRules[r.ip.String()]; !ok { if _, ok := m.outgoingRules[r.ip]; !ok {
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) m.outgoingRules[r.ip] = make(map[string]PeerRule)
} }
m.outgoingRules[r.ip.String()][r.id] = r m.outgoingRules[r.ip][r.id] = r
} }
m.mutex.Unlock() m.mutex.Unlock()
return r.id return r.id
@@ -999,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
if m.forwarder == nil { fwder := m.forwarder.Load()
if fwder == nil {
return nil return nil
} }
m.routingEnabled = false m.routingEnabled.Store(false)
m.nativeRouter = false m.nativeRouter.Store(false)
// don't stop forwarder if in use by netstack // don't stop forwarder if in use by netstack
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
return nil return nil
} }
m.forwarder.Stop() fwder.Stop()
m.forwarder = nil m.forwarder.Store(nil)
log.Debug("forwarder stopped") log.Debug("forwarder stopped")

View File

@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: false, stateful: false,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Single rule allowing all traffic // Single rule allowing all traffic
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, _, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
fw.ActionAccept, "", "allow all")
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Baseline: Single 'allow all' rule without connection tracking", desc: "Baseline: Single 'allow all' rule without connection tracking",
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Add explicit rules matching return traffic pattern // Add explicit rules matching return traffic pattern
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
ip := generateRandomIPs(1)[0] ip := generateRandomIPs(1)[0]
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP, _, err := m.AddPeerFiltering(
nil,
ip,
fw.ProtocolTCP,
&fw.Port{Values: []uint16{uint16(1024 + i)}}, &fw.Port{Values: []uint16{uint16(1024 + i)}},
&fw.Port{Values: []uint16{80}}, &fw.Port{Values: []uint16{80}},
fw.ActionAccept, "", "explicit return") fw.ActionAccept,
"",
)
require.NoError(b, err) require.NoError(b, err)
} }
}, },
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
stateful: true, stateful: true,
setupFunc: func(m *Manager) { setupFunc: func(m *Manager) {
// Add some basic rules but rely on state for established connections // Add some basic rules but rely on state for established connections
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil, _, err := m.AddPeerFiltering(
fw.ActionDrop, "", "default drop") nil,
net.ParseIP("0.0.0.0"),
fw.ProtocolTCP,
nil,
nil,
fw.ActionDrop,
"",
)
require.NoError(b, err) require.NoError(b, err)
}, },
desc: "Connection tracking with established connections", desc: "Connection tracking with established connections",
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
// For stateful scenarios, establish the connection // For stateful scenarios, establish the connection
if sc.stateful { if sc.stateful {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Measure inbound packet processing // Measure inbound packet processing
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
outbound := generatePacket(b, srcIPs[i], dstIPs[i], outbound := generatePacket(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, layers.IPProtocolTCP) uint16(1024+i), 80, layers.IPProtocolTCP)
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
// Test packet // Test packet
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP) testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
// First establish our test connection // First establish our test connection
manager.processOutgoingHooks(testOut) manager.processOutgoingHooks(testOut, 0)
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(testIn) manager.dropFilter(testIn, 0)
} }
}) })
} }
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP) inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
if sc.established { if sc.established {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
// For stateful cases and established connections // For stateful cases and established connections
if !strings.Contains(sc.name, "allow_non_wg") || if !strings.Contains(sc.name, "allow_non_wg") ||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") { (strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
manager.processOutgoingHooks(outbound) manager.processOutgoingHooks(outbound, 0)
// For TCP post-handshake, simulate full handshake // For TCP post-handshake, simulate full handshake
if sc.state == "post_handshake" { if sc.state == "post_handshake" {
// SYN // SYN
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn)) syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
} }
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
manager.dropFilter(inbound) manager.dropFilter(inbound, 0)
} }
}) })
} }
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Initial SYN // Initial SYN
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
// SYN-ACK // SYN-ACK
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
// ACK // ACK
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Prepare test packets simulating bidirectional traffic // Prepare test packets simulating bidirectional traffic
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
// Simulate bidirectional traffic // Simulate bidirectional traffic
// First outbound data // First outbound data
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
// Then inbound response - this is what we're actually measuring // Then inbound response - this is what we're actually measuring
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
} }
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
// Single rule to allow all return traffic from port 80 // Single rule to allow all return traffic from port 80
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Connection establishment // Connection establishment
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
// Data transfer // Data transfer
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
// Connection teardown // Connection teardown
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
} }
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
// Setup initial state based on scenario // Setup initial state based on scenario
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
for i := 0; i < sc.connCount; i++ { for i := 0; i < sc.connCount; i++ {
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPSyn)) uint16(1024+i), 80, uint16(conntrack.TCPSyn))
manager.processOutgoingHooks(syn) manager.processOutgoingHooks(syn, 0)
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
manager.dropFilter(synack) manager.dropFilter(synack, 0)
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
uint16(1024+i), 80, uint16(conntrack.TCPAck)) uint16(1024+i), 80, uint16(conntrack.TCPAck))
manager.processOutgoingHooks(ack) manager.processOutgoingHooks(ack, 0)
} }
// Pre-generate test packets // Pre-generate test packets
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
counter++ counter++
// Simulate bidirectional traffic // Simulate bidirectional traffic
manager.processOutgoingHooks(outPackets[connIdx]) manager.processOutgoingHooks(outPackets[connIdx], 0)
manager.dropFilter(inPackets[connIdx]) manager.dropFilter(inPackets[connIdx], 0)
} }
}) })
}) })
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}) })
if sc.rules { if sc.rules {
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, _, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
&fw.Port{Values: []uint16{80}},
nil,
fw.ActionAccept, "", "return traffic")
require.NoError(b, err) require.NoError(b, err)
} }
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
p := patterns[connIdx] p := patterns[connIdx]
// Full connection lifecycle // Full connection lifecycle
manager.processOutgoingHooks(p.syn) manager.processOutgoingHooks(p.syn, 0)
manager.dropFilter(p.synAck) manager.dropFilter(p.synAck, 0)
manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.ack, 0)
manager.processOutgoingHooks(p.request) manager.processOutgoingHooks(p.request, 0)
manager.dropFilter(p.response) manager.dropFilter(p.response, 0)
manager.processOutgoingHooks(p.finClient) manager.processOutgoingHooks(p.finClient, 0)
manager.dropFilter(p.ackServer) manager.dropFilter(p.ackServer, 0)
manager.dropFilter(p.finServer) manager.dropFilter(p.finServer, 0)
manager.processOutgoingHooks(p.ackClient) manager.processOutgoingHooks(p.ackClient, 0)
} }
}) })
}) })
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
for _, r := range rules { for _, r := range rules {
_, err := manager.AddRouteFiltering( _, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
r.sources,
r.dest,
r.proto,
nil,
r.port,
fw.ActionAccept,
)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
for _, tc := range cases { for _, tc := range cases {
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
} }
} }

View File

@@ -12,9 +12,9 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
func TestPeerACLFiltering(t *testing.T) { func TestPeerACLFiltering(t *testing.T) {
@@ -26,15 +26,15 @@ func TestPeerACLFiltering(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, manager) require.NotNil(t, manager)
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
t.Run("Implicit DROP (no rules)", func(t *testing.T) { t.Run("Implicit DROP (no rules)", func(t *testing.T) {
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.True(t, isDropped, "Packet should be dropped when no rules exist") require.True(t, isDropped, "Packet should be dropped when no rules exist")
}) })
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
rules, err := manager.AddPeerFiltering( rules, err := manager.AddPeerFiltering(
nil,
net.ParseIP(tc.ruleIP), net.ParseIP(tc.ruleIP),
tc.ruleProto, tc.ruleProto,
tc.ruleSrcPort, tc.ruleSrcPort,
tc.ruleDstPort, tc.ruleDstPort,
tc.ruleAction, tc.ruleAction,
"", "",
tc.name,
) )
require.NoError(t, err) require.NoError(t, err)
require.NotEmpty(t, rules) require.NotEmpty(t, rules)
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
}) })
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
isDropped := manager.DropIncoming(packet) isDropped := manager.DropIncoming(packet, 0)
require.Equal(t, tc.shouldBeBlocked, isDropped) require.Equal(t, tc.shouldBeBlocked, isDropped)
}) })
} }
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: localIP, IP: localIP,
Network: wgNet, Network: wgNet,
} }
@@ -302,12 +302,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}, },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NoError(tb, err) require.NoError(tb, err)
require.NotNil(tb, manager) require.NotNil(tb, manager)
require.True(tb, manager.routingEnabled) require.True(tb, manager.routingEnabled.Load())
require.False(tb, manager.nativeRouter) require.False(tb, manager.nativeRouter.Load())
tb.Cleanup(func() { tb.Cleanup(func() {
require.NoError(tb, manager.Close(nil)) require.NoError(tb, manager.Close(nil))
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
tc.rule.sources, tc.rule.sources,
tc.rule.dest, tc.rule.dest,
tc.rule.proto, tc.rule.proto,
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
require.NoError(t, manager.DeleteRouteRule(rule)) require.NoError(t, manager.DeleteRouteRule(rule))
}) })
srcIP := net.ParseIP(tc.srcIP) srcIP := netip.MustParseAddr(tc.srcIP)
dstIP := net.ParseIP(tc.dstIP) dstIP := netip.MustParseAddr(tc.dstIP)
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
// to the forwarder // to the forwarder
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
require.Equal(t, tc.shouldPass, isAllowed) require.Equal(t, tc.shouldPass, isAllowed)
}) })
} }
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
var rules []fw.Rule var rules []fw.Rule
for _, r := range tc.rules { for _, r := range tc.rules {
rule, err := manager.AddRouteFiltering( rule, err := manager.AddRouteFiltering(
nil,
r.sources, r.sources,
r.dest, r.dest,
r.proto, r.proto,
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
}) })
for i, p := range tc.packets { for i, p := range tc.packets {
srcIP := net.ParseIP(p.srcIP) srcIP := netip.MustParseAddr(p.srcIP)
dstIP := net.ParseIP(p.dstIP) dstIP := netip.MustParseAddr(p.dstIP)
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) _, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
} }
}) })

View File

@@ -1,8 +1,10 @@
package uspfilter package uspfilter
import ( import (
"context"
"fmt" "fmt"
"net" "net"
"net/netip"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -16,15 +18,17 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
) )
var logger = log.NewFromLogrus(logrus.StandardLogger()) var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type IFaceMock struct { type IFaceMock struct {
SetFilterFunc func(device.PacketFilter) error SetFilterFunc func(device.PacketFilter) error
AddressFunc func() iface.WGAddress AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice GetDeviceFunc func() *device.FilteredDevice
} }
@@ -50,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
return i.SetFilterFunc(iface) return i.SetFilterFunc(iface)
} }
func (i *IFaceMock) Address() iface.WGAddress { func (i *IFaceMock) Address() wgaddr.Address {
if i.AddressFunc == nil { if i.AddressFunc == nil {
return iface.WGAddress{} return wgaddr.Address{}
} }
return i.AddressFunc() return i.AddressFunc()
} }
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
} }
ip := net.ParseIP("192.168.1.1") ip := netip.MustParseAddr("192.168.1.1")
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule 2"
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok { if _, ok := m.incomingRules[ip][r.ID()]; !ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
} }
for _, r := range rule2 { for _, r := range rule2 {
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok { if _, ok := m.incomingRules[ip][r.ID()]; ok {
t.Errorf("rule2 is not in the incomingRules") t.Errorf("rule2 is not in the incomingRules")
} }
} }
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name string name string
in bool in bool
expDir fw.RuleDirection expDir fw.RuleDirection
ip net.IP ip netip.Addr
dPort uint16 dPort uint16
hook func([]byte) bool hook func([]byte) bool
expectedID string expectedID string
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Outgoing UDP Packet Hook", name: "Test Outgoing UDP Packet Hook",
in: false, in: false,
expDir: fw.RuleDirectionOUT, expDir: fw.RuleDirectionOUT,
ip: net.IPv4(10, 168, 0, 1), ip: netip.MustParseAddr("10.168.0.1"),
dPort: 8000, dPort: 8000,
hook: func([]byte) bool { return true }, hook: func([]byte) bool { return true },
}, },
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
name: "Test Incoming UDP Packet Hook", name: "Test Incoming UDP Packet Hook",
in: true, in: true,
expDir: fw.RuleDirectionIN, expDir: fw.RuleDirectionIN,
ip: net.IPv6loopback, ip: netip.MustParseAddr("::1"),
dPort: 9000, dPort: 9000,
hook: func([]byte) bool { return false }, hook: func([]byte) bool { return false },
}, },
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
var addedRule PeerRule var addedRule PeerRule
if tt.in { if tt.in {
if len(manager.incomingRules[tt.ip.String()]) != 1 { if len(manager.incomingRules[tt.ip]) != 1 {
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
return return
} }
for _, rule := range manager.incomingRules[tt.ip.String()] { for _, rule := range manager.incomingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} else { } else {
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules)) t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
return return
} }
for _, rule := range manager.outgoingRules[tt.ip.String()] { for _, rule := range manager.outgoingRules[tt.ip] {
addedRule = rule addedRule = rule
} }
} }
if !tt.ip.Equal(addedRule.ip) { if tt.ip.Compare(addedRule.ip) != 0 {
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip) t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
return return
} }
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -246,9 +248,8 @@ func TestManagerReset(t *testing.T) {
proto := fw.ProtocolTCP proto := fw.ProtocolTCP
port := &fw.Port{Values: []uint16{80}} port := &fw.Port{Values: []uint16{80}}
action := fw.ActionDrop action := fw.ActionDrop
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -268,8 +269,8 @@ func TestManagerReset(t *testing.T) {
func TestNotMatchByIP(t *testing.T) { func TestNotMatchByIP(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("100.10.0.100"), IP: net.ParseIP("100.10.0.100"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("100.10.0.0"), IP: net.ParseIP("100.10.0.0"),
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false) m, err := Create(ifaceMock, false, flowLogger)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
ip := net.ParseIP("0.0.0.0") ip := net.ParseIP("0.0.0.0")
proto := fw.ProtocolUDP proto := fw.ProtocolUDP
action := fw.ActionAccept action := fw.ActionAccept
comment := "Test rule"
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment) _, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
if err != nil { if err != nil {
t.Errorf("failed to add filtering: %v", err) t.Errorf("failed to add filtering: %v", err)
return return
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
return return
} }
if m.dropFilter(buf.Bytes()) { if m.dropFilter(buf.Bytes(), 0) {
t.Errorf("expected packet to be accepted") t.Errorf("expected packet to be accepted")
return return
} }
@@ -347,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false) manager, err := Create(iface, false, flowLogger)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
// Add a UDP packet hook // Add a UDP packet hook
hookFunc := func(data []byte) bool { return true } hookFunc := func(data []byte) bool { return true }
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc) hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
// Assert the hook is added by finding it in the manager's outgoing rules // Assert the hook is added by finding it in the manager's outgoing rules
found := false found := false
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -401,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
Mask: net.CIDRMask(16, 32), Mask: net.CIDRMask(16, 32),
} }
manager.udpTracker.Close() manager.udpTracker.Close()
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}() }()
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
hookCalled := false hookCalled := false
hookID := manager.AddUDPPacketHook( hookID := manager.AddUDPPacketHook(
false, false,
net.ParseIP("100.10.0.100"), netip.MustParseAddr("100.10.0.100"),
53, 53,
func([]byte) bool { func([]byte) bool {
hookCalled = true hookCalled = true
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Test hook gets called // Test hook gets called
result := manager.processOutgoingHooks(buf.Bytes()) result := manager.processOutgoingHooks(buf.Bytes(), 0)
require.True(t, result) require.True(t, result)
require.True(t, hookCalled) require.True(t, hookCalled)
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
err = gopacket.SerializeLayers(buf, opts, ipv4) err = gopacket.SerializeLayers(buf, opts, ipv4)
require.NoError(t, err) require.NoError(t, err)
result = manager.processOutgoingHooks(buf.Bytes()) result = manager.processOutgoingHooks(buf.Bytes(), 0)
require.False(t, result) require.False(t, result)
} }
@@ -479,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false) manager, err := Create(ifaceMock, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
start := time.Now() start := time.Now()
for i := 0; i < testMax; i++ { for i := 0; i < testMax; i++ {
port := &fw.Port{Values: []uint16{uint16(1000 + i)}} port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic") _, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
require.NoError(t, err, "failed to add rule") require.NoError(t, err, "failed to add rule")
} }
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false) }, false, flowLogger)
require.NoError(t, err) require.NoError(t, err)
manager.wgNetwork = &net.IPNet{ manager.wgNetwork = &net.IPNet{
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
manager.decoders = sync.Pool{ manager.decoders = sync.Pool{
New: func() any { New: func() any {
d := &decoder{ d := &decoder{
@@ -534,8 +534,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
}() }()
// Set up packet parameters // Set up packet parameters
srcIP := net.ParseIP("100.10.0.1") srcIP := netip.MustParseAddr("100.10.0.1")
dstIP := net.ParseIP("100.10.0.100") dstIP := netip.MustParseAddr("100.10.0.100")
srcPort := uint16(51334) srcPort := uint16(51334)
dstPort := uint16(53) dstPort := uint16(53)
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
outboundIPv4 := &layers.IPv4{ outboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: srcIP, SrcIP: srcIP.AsSlice(),
DstIP: dstIP, DstIP: dstIP.AsSlice(),
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
outboundUDP := &layers.UDP{ outboundUDP := &layers.UDP{
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Process outbound packet and verify connection tracking // Process outbound packet and verify connection tracking
drop := manager.DropOutgoing(outboundBuf.Bytes()) drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
require.False(t, drop, "Initial outbound packet should not be dropped") require.False(t, drop, "Initial outbound packet should not be dropped")
// Verify connection was tracked // Verify connection was tracked
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
require.True(t, exists, "Connection should be tracked after outbound packet") require.True(t, exists, "Connection should be tracked after outbound packet")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match") require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match") require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
require.Equal(t, srcPort, conn.SourcePort, "Source port should match") require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
require.Equal(t, dstPort, conn.DestPort, "Destination port should match") require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
inboundIPv4 := &layers.IPv4{ inboundIPv4 := &layers.IPv4{
TTL: 64, TTL: 64,
Version: 4, Version: 4,
SrcIP: dstIP, // Original destination is now source SrcIP: dstIP.AsSlice(), // Original destination is now source
DstIP: srcIP, // Original source is now destination DstIP: srcIP.AsSlice(), // Original source is now destination
Protocol: layers.IPProtocolUDP, Protocol: layers.IPProtocolUDP,
} }
inboundUDP := &layers.UDP{ inboundUDP := &layers.UDP{
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
for _, cp := range checkPoints { for _, cp := range checkPoints {
time.Sleep(cp.sleep) time.Sleep(cp.sleep)
drop = manager.dropFilter(inboundBuf.Bytes()) drop = manager.dropFilter(inboundBuf.Bytes(), 0)
require.Equal(t, cp.shouldAllow, !drop, cp.description) require.Equal(t, cp.shouldAllow, !drop, cp.description)
// If the connection should still be valid, verify it exists // If the connection should still be valid, verify it exists
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
} }
// Create a new outbound connection for invalid tests // Create a new outbound connection for invalid tests
drop = manager.processOutgoingHooks(outboundBuf.Bytes()) drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
require.False(t, drop, "Second outbound packet should not be dropped") require.False(t, drop, "Second outbound packet should not be dropped")
for _, tc := range invalidCases { for _, tc := range invalidCases {
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Verify the invalid packet is dropped // Verify the invalid packet is dropped
drop = manager.dropFilter(testBuf.Bytes()) drop = manager.dropFilter(testBuf.Bytes(), 0)
require.True(t, drop, tc.description) require.True(t, drop, tc.description)
}) })
} }

View File

@@ -13,6 +13,8 @@ import (
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type RecvMessage struct { type RecvMessage struct {
@@ -51,9 +53,10 @@ type ICEBind struct {
muUDPMux sync.Mutex muUDPMux sync.Mutex
udpMux *UniversalUDPMuxDefault udpMux *UniversalUDPMuxDefault
address wgaddr.Address
} }
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind { func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind) b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
ib := &ICEBind{ ib := &ICEBind{
StdNetBind: b, StdNetBind: b,
@@ -63,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
endpoints: make(map[netip.Addr]net.Conn), endpoints: make(map[netip.Addr]net.Conn),
closedChan: make(chan struct{}), closedChan: make(chan struct{}),
closed: true, closed: true,
address: address,
} }
rc := receiverCreator{ rc := receiverCreator{
@@ -142,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
s.udpMux = NewUniversalUDPMuxDefault( s.udpMux = NewUniversalUDPMuxDefault(
UniversalUDPMuxParams{ UniversalUDPMuxParams{
UDPConn: conn, UDPConn: conn,
Net: s.transportNet, Net: s.transportNet,
FilterFn: s.filterFn, FilterFn: s.filterFn,
WGAddress: s.address,
}, },
) )
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {

View File

@@ -17,6 +17,8 @@ import (
"github.com/pion/logging" "github.com/pion/logging"
"github.com/pion/stun/v2" "github.com/pion/stun/v2"
"github.com/pion/transport/v3" "github.com/pion/transport/v3"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// FilterFn is a function that filters out candidates based on the address. // FilterFn is a function that filters out candidates based on the address.
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
XORMappedAddrCacheTTL time.Duration XORMappedAddrCacheTTL time.Duration
Net transport.Net Net transport.Net
FilterFn FilterFn FilterFn FilterFn
WGAddress wgaddr.Address
} }
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux // NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
mux: m, mux: m,
logger: params.Logger, logger: params.Logger,
filterFn: params.FilterFn, filterFn: params.FilterFn,
address: params.WGAddress,
} }
// embed UDPMux // embed UDPMux
@@ -118,6 +122,7 @@ type udpConn struct {
filterFn FilterFn filterFn FilterFn
// TODO: reset cache on route changes // TODO: reset cache on route changes
addrCache sync.Map addrCache sync.Map
address wgaddr.Address
} }
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) { func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
return nil return nil
} }
if u.address.Network.Contains(a.AsSlice()) {
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
if isRouted, prefix, err := u.filterFn(a); err != nil { if isRouted, prefix, err := u.filterFn(a); err != nil {
log.Errorf("Failed to check if address %s is routed: %v", addr, err) log.Errorf("Failed to check if address %s is routed: %v", addr, err)
} else { } else {

View File

@@ -9,13 +9,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create() (device.WGConfigurer, error) Create() (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform // WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
type WGTunDevice struct { type WGTunDevice struct {
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -31,7 +32,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error { func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
return t.name return t.name
} }
func (t *WGTunDevice) WgAddress() WGAddress { func (t *WGTunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,11 +13,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -29,7 +30,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -2,6 +2,7 @@ package device
import ( import (
"net" "net"
"net/netip"
"sync" "sync"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
@@ -10,16 +11,16 @@ import (
// PacketFilter interface for firewall abilities // PacketFilter interface for firewall abilities
type PacketFilter interface { type PacketFilter interface {
// DropOutgoing filter outgoing packets from host to external destinations // DropOutgoing filter outgoing packets from host to external destinations
DropOutgoing(packetData []byte) bool DropOutgoing(packetData []byte, size int) bool
// DropIncoming filter incoming packets from external sources to host // DropIncoming filter incoming packets from external sources to host
DropIncoming(packetData []byte) bool DropIncoming(packetData []byte, size int) bool
// AddUDPPacketHook calls hook when UDP packet from given direction matched // AddUDPPacketHook calls hook when UDP packet from given direction matched
// //
// Hook function returns flag which indicates should be the matched package dropped or not. // Hook function returns flag which indicates should be the matched package dropped or not.
// Hook function receives raw network packet data as argument. // Hook function receives raw network packet data as argument.
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
// RemovePacketHook removes hook by ID // RemovePacketHook removes hook by ID
RemovePacketHook(hookID string) error RemovePacketHook(hookID string) error
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) { if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...) bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...) sizes = append(sizes[:i], sizes[i+1:]...)
n-- n--
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs)) filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0 dropped := 0
for _, buf := range bufs { for _, buf := range bufs {
if !filter.DropIncoming(buf[offset:]) { if !filter.DropIncoming(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf) filteredBufs = append(filteredBufs, buf)
dropped++ dropped++
} }

View File

@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
tun.EXPECT().Write(mockBufs, 0).Return(0, nil) tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropIncoming(gomock.Any()).Return(true) filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
return 1, nil return 1, nil
}) })
filter := mocks.NewMockPacketFilter(ctrl) filter := mocks.NewMockPacketFilter(ctrl)
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true) filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
wrapped := newDeviceFilter(tun) wrapped := newDeviceFilter(tun)
wrapped.filter = filter wrapped.filter = filter

View File

@@ -14,11 +14,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
iceBind *bind.ICEBind iceBind *bind.ICEBind
@@ -30,7 +31,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }
func (t *TunDevice) UpdateAddr(addr WGAddress) error { func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
// todo implement // todo implement
return nil return nil
} }

View File

@@ -14,12 +14,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/sharedsock" "github.com/netbirdio/netbird/sharedsock"
) )
type TunKernelDevice struct { type TunKernelDevice struct {
name string name string
address WGAddress address wgaddr.Address
wgPort int wgPort int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
filterFn bind.FilterFn filterFn bind.FilterFn
} }
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice { func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
return &TunKernelDevice{ return &TunKernelDevice{
ctx: ctx, ctx: ctx,
@@ -99,9 +100,10 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return nil, err return nil, err
} }
bindParams := bind.UniversalUDPMuxParams{ bindParams := bind.UniversalUDPMuxParams{
UDPConn: rawSock, UDPConn: rawSock,
Net: t.transportNet, Net: t.transportNet,
FilterFn: t.filterFn, FilterFn: t.filterFn,
WGAddress: t.address,
} }
mux := bind.NewUniversalUDPMuxDefault(bindParams) mux := bind.NewUniversalUDPMuxDefault(bindParams)
go mux.ReadFromConn(t.ctx) go mux.ReadFromConn(t.ctx)
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return t.udpMux, nil return t.udpMux, nil
} }
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error { func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
return closErr return closErr
} }
func (t *TunKernelDevice) WgAddress() WGAddress { func (t *TunKernelDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,12 +13,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nbnet "github.com/netbirdio/netbird/util/net" nbnet "github.com/netbirdio/netbird/util/net"
) )
type TunNetstackDevice struct { type TunNetstackDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
net *netstack.Net net *netstack.Net
} }
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice { func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
return &TunNetstackDevice{ return &TunNetstackDevice{
name: name, name: name,
address: address, address: address,
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error { func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
return nil return nil
} }
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
return nil return nil
} }
func (t *TunNetstackDevice) WgAddress() WGAddress { func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -12,11 +12,12 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type USPDevice struct { type USPDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -28,7 +29,7 @@ type USPDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice { func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
log.Infof("using userspace bind mode") log.Infof("using userspace bind mode")
return &USPDevice{ return &USPDevice{
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *USPDevice) UpdateAddr(address WGAddress) error { func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
return nil return nil
} }
func (t *USPDevice) WgAddress() WGAddress { func (t *USPDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -13,13 +13,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}" const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
type TunDevice struct { type TunDevice struct {
name string name string
address WGAddress address wgaddr.Address
port int port int
key string key string
mtu int mtu int
@@ -32,7 +33,7 @@ type TunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice { func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
return &TunDevice{ return &TunDevice{
name: name, name: name,
address: address, address: address,
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
return udpMux, nil return udpMux, nil
} }
func (t *TunDevice) UpdateAddr(address WGAddress) error { func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
t.address = address t.address = address
return t.assignAddr() return t.assignAddr()
} }
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
} }
return nil return nil
} }
func (t *TunDevice) WgAddress() WGAddress { func (t *TunDevice) WgAddress() wgaddr.Address {
return t.address return t.address
} }

View File

@@ -6,6 +6,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/freebsd" "github.com/netbirdio/netbird/client/iface/freebsd"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
link, err := freebsd.LinkByName(l.name) link, err := freebsd.LinkByName(l.name)
if err != nil { if err != nil {
return fmt.Errorf("link by name: %w", err) return fmt.Errorf("link by name: %w", err)

View File

@@ -8,6 +8,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type wgLink struct { type wgLink struct {
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
return nil return nil
} }
func (l *wgLink) assignAddr(address WGAddress) error { func (l *wgLink) assignAddr(address wgaddr.Address) error {
//delete existing addresses //delete existing addresses
list, err := netlink.AddrList(l, 0) list, err := netlink.AddrList(l, 0)
if err != nil { if err != nil {

View File

@@ -7,13 +7,14 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
type WGTunDevice interface { type WGTunDevice interface {
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error) Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(address WGAddress) error UpdateAddr(address wgaddr.Address) error
WgAddress() WGAddress WgAddress() wgaddr.Address
DeviceName() string DeviceName() string
Close() error Close() error
FilteredDevice() *device.FilteredDevice FilteredDevice() *device.FilteredDevice

View File

@@ -19,6 +19,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -28,8 +29,6 @@ const (
WgInterfaceDefault = configurer.WgInterfaceDefault WgInterfaceDefault = configurer.WgInterfaceDefault
) )
type WGAddress = device.WGAddress
type wgProxyFactory interface { type wgProxyFactory interface {
GetProxy() wgproxy.Proxy GetProxy() wgproxy.Proxy
Free() error Free() error
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
} }
// Address returns the interface address // Address returns the interface address
func (w *WGIface) Address() device.WGAddress { func (w *WGIface) Address() wgaddr.Address {
return w.tun.WgAddress() return w.tun.WgAddress()
} }
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
w.mu.Lock() w.mu.Lock()
defer w.mu.Unlock() defer w.mu.Unlock()
addr, err := device.ParseWGAddress(newAddr) addr, err := wgaddr.ParseWGAddress(newAddr)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,17 +3,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,

View File

@@ -6,17 +6,18 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -5,17 +5,18 @@ package iface
import ( import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace := &WGIface{ wgIFace := &WGIface{
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd), tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),

View File

@@ -8,12 +8,13 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{} wgIFace := &WGIface{}
if netstack.IsEnabled() { if netstack.IsEnabled() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr()) wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
return wgIFace, nil return wgIFace, nil
} }
if device.ModuleTunIsLoaded() { if device.ModuleTunIsLoaded() {
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind) wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
wgIFace.userspaceBind = true wgIFace.userspaceBind = true
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind) wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)

View File

@@ -4,16 +4,17 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/netstack"
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
// NewWGIFace Creates a new WireGuard interface instance // NewWGIFace Creates a new WireGuard interface instance
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) { func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgAddress, err := device.ParseWGAddress(opts.Address) wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn) iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
var tun WGTunDevice var tun WGTunDevice
if netstack.IsEnabled() { if netstack.IsEnabled() {

View File

@@ -6,6 +6,7 @@ package mocks
import ( import (
net "net" net "net"
"net/netip"
reflect "reflect" reflect "reflect"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
} }
// AddUDPPacketHook mocks base method. // AddUDPPacketHook mocks base method.
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string { func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(string)
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
} }
// DropIncoming mocks base method. // DropIncoming mocks base method.
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool { func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropIncoming", arg0) ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropIncoming indicates an expected call of DropIncoming. // DropIncoming indicates an expected call of DropIncoming.
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
} }
// DropOutgoing mocks base method. // DropOutgoing mocks base method.
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool { func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DropOutgoing", arg0) ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(bool)
return ret0 return ret0
} }
// DropOutgoing indicates an expected call of DropOutgoing. // DropOutgoing indicates an expected call of DropOutgoing.
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call { func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
} }
// RemovePacketHook mocks base method. // RemovePacketHook mocks base method.

View File

@@ -1,29 +1,29 @@
package device package wgaddr
import ( import (
"fmt" "fmt"
"net" "net"
) )
// WGAddress WireGuard parsed address // Address WireGuard parsed address
type WGAddress struct { type Address struct {
IP net.IP IP net.IP
Network *net.IPNet Network *net.IPNet
} }
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address // ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
func ParseWGAddress(address string) (WGAddress, error) { func ParseWGAddress(address string) (Address, error) {
ip, network, err := net.ParseCIDR(address) ip, network, err := net.ParseCIDR(address)
if err != nil { if err != nil {
return WGAddress{}, err return Address{}, err
} }
return WGAddress{ return Address{
IP: ip, IP: ip,
Network: network, Network: network,
}, nil }, nil
} }
func (addr WGAddress) String() string { func (addr Address) String() string {
maskSize, _ := addr.Network.Mask.Size() maskSize, _ := addr.Network.Mask.Size()
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize) return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
} }

View File

@@ -22,6 +22,8 @@
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}" !define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}" !define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
Unicode True Unicode True
###################################################################### ######################################################################
@@ -68,6 +70,9 @@ ShowInstDetails Show
!insertmacro MUI_PAGE_DIRECTORY !insertmacro MUI_PAGE_DIRECTORY
; Custom page for autostart checkbox
Page custom AutostartPage AutostartPageLeave
!insertmacro MUI_PAGE_INSTFILES !insertmacro MUI_PAGE_INSTFILES
!insertmacro MUI_PAGE_FINISH !insertmacro MUI_PAGE_FINISH
@@ -80,8 +85,36 @@ ShowInstDetails Show
!insertmacro MUI_LANGUAGE "English" !insertmacro MUI_LANGUAGE "English"
; Variables for autostart option
Var AutostartCheckbox
Var AutostartEnabled
###################################################################### ######################################################################
; Function to create the autostart options page
Function AutostartPage
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
nsDialogs::Create 1018
Pop $0
${If} $0 == error
Abort
${EndIf}
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
Pop $AutostartCheckbox
${NSD_Check} $AutostartCheckbox ; Default to checked
StrCpy $AutostartEnabled "1" ; Default to enabled
nsDialogs::Show
FunctionEnd
; Function to handle leaving the autostart page
Function AutostartPageLeave
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
FunctionEnd
Function GetAppFromCommand Function GetAppFromCommand
Exch $1 Exch $1
Push $2 Push $2
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}" WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
EnVar::SetHKLM EnVar::SetHKLM
EnVar::AddValueEx "path" "$INSTDIR" EnVar::AddValueEx "path" "$INSTDIR"
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall' ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
# kill ui client # kill ui client
ExecWait `taskkill /im ${UI_APP_EXE}.exe` ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
# wait the service uninstall take unblock the executable # wait the service uninstall take unblock the executable
Sleep 3000 Sleep 3000

View File

@@ -12,7 +12,7 @@ import (
type RuleID string type RuleID string
func (r RuleID) GetRuleID() string { func (r RuleID) ID() string {
return string(r) return string(r)
} }

View File

@@ -240,12 +240,12 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
dPorts := convertPortInfo(rule.PortInfo) dPorts := convertPortInfo(rule.PortInfo)
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action) addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
if err != nil { if err != nil {
return "", fmt.Errorf("add route rule: %w", err) return "", fmt.Errorf("add route rule: %w", err)
} }
return id.RuleID(addedRule.GetRuleID()), nil return id.RuleID(addedRule.ID()), nil
} }
func (d *DefaultManager) protoRuleToFirewallRule( func (d *DefaultManager) protoRuleToFirewallRule(
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
} }
} }
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "") ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok { if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
return ruleID, rulesPair, nil return ruleID, rulesPair, nil
} }
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
var rules []firewall.Rule var rules []firewall.Rule
switch r.Direction { switch r.Direction {
case mgmProto.RuleDirection_IN: case mgmProto.RuleDirection_IN:
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
case mgmProto.RuleDirection_OUT: case mgmProto.RuleDirection_OUT:
// TODO: Remove this soon. Outbound rules are obsolete. // TODO: Remove this soon. Outbound rules are obsolete.
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already // We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "") rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
default: default:
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule") return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
} }
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
} }
func (d *DefaultManager) addInRules( func (d *DefaultManager) addInRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
} }
func (d *DefaultManager) addOutRules( func (d *DefaultManager) addOutRules(
id []byte,
ip net.IP, ip net.IP,
protocol firewall.Protocol, protocol firewall.Protocol,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
ipsetName string, ipsetName string,
comment string,
) ([]firewall.Rule, error) { ) ([]firewall.Rule, error) {
if shouldSkipInvertedRule(protocol, port) { if shouldSkipInvertedRule(protocol, port) {
return nil, nil return nil, nil
} }
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment) rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
if err != nil { if err != nil {
return nil, fmt.Errorf("add firewall rule: %w", err) return nil, fmt.Errorf("add firewall rule: %w", err)
} }
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
direction int, direction int,
port *firewall.Port, port *firewall.Port,
action firewall.Action, action firewall.Action,
comment string,
) id.RuleID { ) id.RuleID {
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
if port != nil { if port != nil {
idStr += port.String() idStr += port.String()
} }
@@ -515,7 +514,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
for _, rules := range newRulePairs { for _, rules := range newRulePairs {
for _, rule := range rules { for _, rule := range rules {
if err := d.firewall.DeletePeerRule(rule); err != nil { if err := d.firewall.DeletePeerRule(rule); err != nil {
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err) log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package acl package acl
import ( import (
"context"
"net" "net"
"testing" "testing"
@@ -8,11 +9,14 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
mgmProto "github.com/netbirdio/netbird/management/proto" mgmProto "github.com/netbirdio/netbird/management/proto"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
func TestDefaultManager(t *testing.T) { func TestDefaultManager(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
FirewallRules: []*mgmProto.FirewallRule{ FirewallRules: []*mgmProto.FirewallRule{
@@ -45,14 +49,14 @@ func TestDefaultManager(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return
@@ -74,7 +78,7 @@ func TestDefaultManager(t *testing.T) {
t.Run("add extra rules", func(t *testing.T) { t.Run("add extra rules", func(t *testing.T) {
existedPairs := map[string]struct{}{} existedPairs := map[string]struct{}{}
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
existedPairs[id.GetRuleID()] = struct{}{} existedPairs[id.ID()] = struct{}{}
} }
// remove first rule // remove first rule
@@ -100,7 +104,7 @@ func TestDefaultManager(t *testing.T) {
// check that old rule was removed // check that old rule was removed
previousCount := 0 previousCount := 0
for id := range acl.peerRulesPairs { for id := range acl.peerRulesPairs {
if _, ok := existedPairs[id.GetRuleID()]; ok { if _, ok := existedPairs[id.ID()]; ok {
previousCount++ previousCount++
} }
} }
@@ -339,14 +343,14 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
} }
ifaceMock.EXPECT().Name().Return("lo").AnyTimes() ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
ifaceMock.EXPECT().Address().Return(iface.WGAddress{ ifaceMock.EXPECT().Address().Return(wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
// we receive one rule from the management so for testing purposes ignore it // we receive one rule from the management so for testing purposes ignore it
fw, err := firewall.NewFirewall(ifaceMock, nil, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
if err != nil { if err != nil {
t.Errorf("create firewall: %v", err) t.Errorf("create firewall: %v", err)
return return

View File

@@ -10,8 +10,8 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
iface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// MockIFaceMapper is a mock of IFaceMapper interface. // MockIFaceMapper is a mock of IFaceMapper interface.
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
} }
// Address mocks base method. // Address mocks base method.
func (m *MockIFaceMapper) Address() iface.WGAddress { func (m *MockIFaceMapper) Address() wgaddr.Address {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Address") ret := m.ctrl.Call(m, "Address")
ret0, _ := ret[0].(iface.WGAddress) ret0, _ := ret[0].(wgaddr.Address)
return ret0 return ret0
} }

View File

@@ -61,7 +61,7 @@ func NewConnectClient(
} }
// Run with main logic. // Run with main logic.
func (c *ConnectClient) Run(runningChan chan error) error { func (c *ConnectClient) Run(runningChan chan struct{}) error {
return c.run(MobileDependency{}, runningChan) return c.run(MobileDependency{}, runningChan)
} }
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
return c.run(mobileDependency, nil) return c.run(mobileDependency, nil)
} }
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error { func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
rec := c.statusRecorder rec := c.statusRecorder
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
} }
defer c.statusRecorder.ClientStop() defer c.statusRecorder.ClientStop()
runningChanOpen := true
operation := func() error { operation := func() error {
// if context cancelled we not start new backoff cycle // if context cancelled we not start new backoff cycle
if c.isContextCancelled() { if c.ctx.Err() != nil {
return nil return nil
} }
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress()) log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected) state.Set(StatusConnected)
if runningChan != nil && runningChanOpen { if runningChan != nil {
runningChan <- nil select {
close(runningChan) case runningChan <- struct{}{}:
runningChanOpen = false default:
}
} }
<-engineCtx.Done() <-engineCtx.Done()
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
return nil return nil
} }
func (c *ConnectClient) isContextCancelled() bool {
select {
case <-c.ctx.Done():
return true
default:
return false
}
}
// SetNetworkMapPersistence enables or disables network map persistence. // SetNetworkMapPersistence enables or disables network map persistence.
// When enabled, the last received network map will be stored and can be retrieved // When enabled, the last received network map will be stored and can be retrieved
// through the Engine's getLatestNetworkMap method. When disabled, any stored // through the Engine's getLatestNetworkMap method. When disabled, any stored

View File

@@ -22,6 +22,8 @@ import (
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
pfmock "github.com/netbirdio/netbird/client/iface/mocks" pfmock "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
@@ -29,6 +31,8 @@ import (
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
) )
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
type mocWGIface struct { type mocWGIface struct {
filter device.PacketFilter filter device.PacketFilter
} }
@@ -37,9 +41,9 @@ func (w *mocWGIface) Name() string {
panic("implement me") panic("implement me")
} }
func (w *mocWGIface) Address() iface.WGAddress { func (w *mocWGIface) Address() wgaddr.Address {
ip, network, _ := net.ParseCIDR("100.66.100.0/24") ip, network, _ := net.ParseCIDR("100.66.100.0/24")
return iface.WGAddress{ return wgaddr.Address{
IP: ip, IP: ip,
Network: network, Network: network,
} }
@@ -455,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
} }
packetfilter := pfmock.NewMockPacketFilter(ctrl) packetfilter := pfmock.NewMockPacketFilter(ctrl)
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes() packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()) packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
packetfilter.EXPECT().RemovePacketHook(gomock.Any()) packetfilter.EXPECT().RemovePacketHook(gomock.Any())
packetfilter.EXPECT().SetNetwork(ipNet) packetfilter.EXPECT().SetNetwork(ipNet)
@@ -916,7 +920,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err return nil, err
} }
pf, err := uspfilter.Create(wgIface, false) pf, err := uspfilter.Create(wgIface, false, flowLogger)
if err != nil { if err != nil {
t.Fatalf("failed to create uspfilter: %v", err) t.Fatalf("failed to create uspfilter: %v", err)
return nil, err return nil, err

View File

@@ -2,7 +2,7 @@ package dns
import ( import (
"fmt" "fmt"
"net" "net/netip"
"sync" "sync"
"github.com/google/gopacket" "github.com/google/gopacket"
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
return true return true
} }
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil ip, err := netip.ParseAddr(s.runtimeIP)
if err != nil {
return "", fmt.Errorf("parse runtime ip: %w", err)
}
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
} }

View File

@@ -5,15 +5,15 @@ package dns
import ( import (
"net" "net"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter

View File

@@ -1,15 +1,15 @@
package dns package dns
import ( import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
) )
// WGIface defines subset methods of interface required for manager // WGIface defines subset methods of interface required for manager
type WGIface interface { type WGIface interface {
Name() string Name() string
Address() iface.WGAddress Address() wgaddr.Address
IsUserspaceBind() bool IsUserspaceBind() bool
GetFilter() device.PacketFilter GetFilter() device.PacketFilter
GetDevice() *device.FilteredDevice GetDevice() *device.FilteredDevice

View File

@@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
return nil return nil
} }
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "") dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil { if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err) log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err return err

View File

@@ -25,7 +25,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
@@ -33,6 +33,9 @@ import (
"github.com/netbirdio/netbird/client/internal/acl" "github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/dnsfwd"
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/netflow"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/networkmonitor" "github.com/netbirdio/netbird/client/internal/networkmonitor"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/peer/guard" "github.com/netbirdio/netbird/client/internal/peer/guard"
@@ -169,10 +172,11 @@ type Engine struct {
statusRecorder *peer.Status statusRecorder *peer.Status
firewall manager.Manager firewall firewallManager.Manager
routeManager routemanager.Manager routeManager routemanager.Manager
acl acl.Manager acl acl.Manager
dnsForwardMgr *dnsfwd.Manager dnsForwardMgr *dnsfwd.Manager
ingressGatewayMgr *ingressgw.Manager
dnsServer dns.Server dnsServer dns.Server
@@ -187,6 +191,7 @@ type Engine struct {
persistNetworkMap bool persistNetworkMap bool
latestNetworkMap *mgmProto.NetworkMap latestNetworkMap *mgmProto.NetworkMap
connSemaphore *semaphoregroup.SemaphoreGroup connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
} }
// Peer is an instance of the Connection Peer // Peer is an instance of the Connection Peer
@@ -266,6 +271,13 @@ func (e *Engine) Stop() error {
// stop/restore DNS first so dbus and friends don't complain because of a missing interface // stop/restore DNS first so dbus and friends don't complain because of a missing interface
e.stopDNSServer() e.stopDNSServer()
if e.ingressGatewayMgr != nil {
if err := e.ingressGatewayMgr.Close(); err != nil {
log.Warnf("failed to cleanup forward rules: %v", err)
}
e.ingressGatewayMgr = nil
}
if e.routeManager != nil { if e.routeManager != nil {
e.routeManager.Stop(e.stateManager) e.routeManager.Stop(e.stateManager)
} }
@@ -299,6 +311,12 @@ func (e *Engine) Stop() error {
time.Sleep(500 * time.Millisecond) time.Sleep(500 * time.Millisecond)
e.close() e.close()
// stop flow manager after wg interface is gone
if e.flowManager != nil {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine") log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
@@ -333,6 +351,10 @@ func (e *Engine) Start() error {
} }
e.wgInterface = wgIface e.wgInterface = wgIface
// start flow manager right after interface creation
publicKey := e.config.WgPrivateKey.PublicKey()
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
if e.config.RosenpassEnabled { if e.config.RosenpassEnabled {
log.Infof("rosenpass is enabled") log.Infof("rosenpass is enabled")
if e.config.RosenpassPermissive { if e.config.RosenpassPermissive {
@@ -439,7 +461,7 @@ func (e *Engine) createFirewall() error {
} }
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
if err != nil || e.firewall == nil { if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
return nil return nil
@@ -469,16 +491,16 @@ func (e *Engine) initFirewall() error {
} }
rosenpassPort := e.rpManager.GetAddress().Port rosenpassPort := e.rpManager.GetAddress().Port
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}} port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
// this rule is static and will be torn down on engine down by the firewall manager // this rule is static and will be torn down on engine down by the firewall manager
if _, err := e.firewall.AddPeerFiltering( if _, err := e.firewall.AddPeerFiltering(
nil,
net.IP{0, 0, 0, 0}, net.IP{0, 0, 0, 0},
manager.ProtocolUDP, firewallManager.ProtocolUDP,
nil, nil,
&port, &port,
manager.ActionAccept, firewallManager.ActionAccept,
"",
"", "",
); err != nil { ); err != nil {
log.Errorf("failed to allow rosenpass interface traffic: %v", err) log.Errorf("failed to allow rosenpass interface traffic: %v", err)
@@ -503,12 +525,13 @@ func (e *Engine) blockLanAccess() {
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0) v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
for _, network := range toBlock { for _, network := range toBlock {
if _, err := e.firewall.AddRouteFiltering( if _, err := e.firewall.AddRouteFiltering(
nil,
[]netip.Prefix{v4}, []netip.Prefix{v4},
network, network,
manager.ProtocolALL, firewallManager.ProtocolALL,
nil, nil,
nil, nil,
manager.ActionDrop, firewallManager.ActionDrop,
); err != nil { ); err != nil {
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err)) merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
} }
@@ -633,25 +656,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
stunTurn = append(stunTurn, e.TURNs...) stunTurn = append(stunTurn, e.TURNs...)
e.stunTurn.Store(stunTurn) e.stunTurn.Store(stunTurn)
relayMsg := wCfg.GetRelay() err = e.handleRelayUpdate(wCfg.GetRelay())
if relayMsg != nil { if err != nil {
// when we receive token we expect valid address list too return err
c := &auth.Token{ }
Payload: relayMsg.GetTokenPayload(),
Signature: relayMsg.GetTokenSignature(),
}
if err := e.relayManager.UpdateToken(c); err != nil {
log.Errorf("failed to update relay token: %v", err)
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(relayMsg.Urls) err = e.handleFlowUpdate(wCfg.GetFlow())
if err != nil {
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled. return fmt.Errorf("handle the flow configuration: %w", err)
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
} }
// todo update signal // todo update signal
@@ -682,6 +694,55 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
return nil return nil
} }
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
if update != nil {
// when we receive token we expect valid address list too
c := &auth.Token{
Payload: update.GetTokenPayload(),
Signature: update.GetTokenSignature(),
}
if err := e.relayManager.UpdateToken(c); err != nil {
return fmt.Errorf("update relay token: %w", err)
}
e.relayManager.UpdateServerURLs(update.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
_ = e.relayManager.Serve()
} else {
e.relayManager.UpdateServerURLs(nil)
}
return nil
}
func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
if config == nil {
return nil
}
flowConfig, err := toFlowLoggerConfig(config)
if err != nil {
return err
}
return e.flowManager.Update(flowConfig)
}
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
if config.GetInterval() == nil {
return nil, errors.New("flow interval is nil")
}
return &nftypes.FlowConfig{
Enabled: config.GetEnabled(),
Counters: config.GetCounters(),
URL: config.GetUrl(),
TokenPayload: config.GetTokenPayload(),
TokenSignature: config.GetTokenSignature(),
Interval: config.GetInterval().AsDuration(),
}, nil
}
// updateChecksIfNew updates checks if there are changes and sync new meta with management // updateChecksIfNew updates checks if there are changes and sync new meta with management
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error { func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
// if checks are equal, we skip the update // if checks are equal, we skip the update
@@ -912,6 +973,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
log.Errorf("failed to update clientRoutes, err: %v", err) log.Errorf("failed to update clientRoutes, err: %v", err)
} }
// Ingress forward rules
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
log.Errorf("failed to update forward rules, err: %v", err)
}
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
e.updateOfflinePeers(networkMap.GetOfflinePeers()) e.updateOfflinePeers(networkMap.GetOfflinePeers())
@@ -1482,7 +1548,7 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
} }
// GetFirewallManager returns the firewall manager // GetFirewallManager returns the firewall manager
func (e *Engine) GetFirewallManager() manager.Manager { func (e *Engine) GetFirewallManager() firewallManager.Manager {
return e.firewall return e.firewall
} }
@@ -1575,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns) return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
} }
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() { func (e *Engine) restartEngine() {
log.Info("restarting engine") e.syncMsgMux.Lock()
CtxGetState(e.ctx).Set(StatusConnecting) defer e.syncMsgMux.Unlock()
if err := e.Stop(); err != nil { if e.ctx.Err() != nil {
log.Errorf("Failed to stop engine: %v", err) return
} }
log.Info("restarting engine")
CtxGetState(e.ctx).Set(StatusConnecting)
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection) _ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
log.Infof("cancelling client, engine will be recreated") log.Infof("cancelling client context, engine will be recreated")
e.clientCancel() e.clientCancel()
} }
@@ -1596,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() {
e.networkMonitor = networkmonitor.New() e.networkMonitor = networkmonitor.New()
go func() { go func() {
var mu sync.Mutex if err := e.networkMonitor.Listen(e.ctx); err != nil {
var debounceTimer *time.Timer if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
// Start the network monitor with a callback, Start will block until the monitor is stopped, return
// a network change is detected, or an error occurs on start up
err := e.networkMonitor.Start(e.ctx, func() {
// This function is called when a network change is detected
mu.Lock()
defer mu.Unlock()
if debounceTimer != nil {
log.Infof("Network monitor: detected network change, reset debounceTimer")
debounceTimer.Stop()
} }
log.Errorf("network monitor error: %v", err)
// Set a new timer to debounce rapid network changes return
debounceTimer = time.AfterFunc(2*time.Second, func() {
// This function is called after the debounce period
mu.Lock()
defer mu.Unlock()
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
})
})
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
log.Errorf("Network monitor: %v", err)
} }
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
}() }()
} }
@@ -1770,6 +1822,74 @@ func (e *Engine) Address() (netip.Addr, error) {
return ip.Unmap(), nil return ip.Unmap(), nil
} }
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
return nil
}
if len(rules) == 0 {
if e.ingressGatewayMgr == nil {
return nil
}
err := e.ingressGatewayMgr.Close()
e.ingressGatewayMgr = nil
e.statusRecorder.SetIngressGwMgr(nil)
return err
}
if e.ingressGatewayMgr == nil {
mgr := ingressgw.NewManager(e.firewall)
e.ingressGatewayMgr = mgr
e.statusRecorder.SetIngressGwMgr(mgr)
}
var merr *multierror.Error
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
for _, rule := range rules {
proto, err := convertToFirewallProtocol(rule.GetProtocol())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
continue
}
dstPortInfo, err := convertPortInfo(rule.GetDestinationPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid destination port '%v': %w", rule.GetDestinationPort(), err))
continue
}
translateIP, err := convertToIP(rule.GetTranslatedAddress())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("failed to convert translated address '%s': %w", rule.GetTranslatedAddress(), err))
continue
}
translatePort, err := convertPortInfo(rule.GetTranslatedPort())
if err != nil {
merr = multierror.Append(merr, fmt.Errorf("invalid translate port '%v': %w", rule.GetTranslatedPort(), err))
continue
}
forwardRule := firewallManager.ForwardRule{
Protocol: proto,
DestinationPort: *dstPortInfo,
TranslatedAddress: translateIP,
TranslatedPort: *translatePort,
}
forwardingRules = append(forwardingRules, forwardRule)
}
log.Infof("updating forwarding rules: %d", len(forwardingRules))
if err := e.ingressGatewayMgr.Update(forwardingRules); err != nil {
log.Errorf("failed to update forwarding rules: %v", err)
}
return nberrors.FormatErrorOrNil(merr)
}
// isChecksEqual checks if two slices of checks are equal. // isChecksEqual checks if two slices of checks are equal.
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool { func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
for _, check := range checks { for _, check := range checks {

View File

@@ -31,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
@@ -44,6 +45,7 @@ import (
mgmtProto "github.com/netbirdio/netbird/management/proto" mgmtProto "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
@@ -74,7 +76,7 @@ type MockWGIface struct {
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
IsUserspaceBindFunc func() bool IsUserspaceBindFunc func() bool
NameFunc func() string NameFunc func() string
AddressFunc func() device.WGAddress AddressFunc func() wgaddr.Address
ToInterfaceFunc func() *net.Interface ToInterfaceFunc func() *net.Interface
UpFunc func() (*bind.UniversalUDPMuxDefault, error) UpFunc func() (*bind.UniversalUDPMuxDefault, error)
UpdateAddrFunc func(newAddr string) error UpdateAddrFunc func(newAddr string) error
@@ -113,7 +115,7 @@ func (m *MockWGIface) Name() string {
return m.NameFunc() return m.NameFunc()
} }
func (m *MockWGIface) Address() device.WGAddress { func (m *MockWGIface) Address() wgaddr.Address {
return m.AddressFunc() return m.AddressFunc()
} }
@@ -363,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
RemovePeerFunc: func(peerKey string) error { RemovePeerFunc: func(peerKey string) error {
return nil return nil
}, },
AddressFunc: func() iface.WGAddress { AddressFunc: func() wgaddr.Address {
return iface.WGAddress{ return wgaddr.Address{
IP: net.ParseIP("10.20.0.1"), IP: net.ParseIP("10.20.0.1"),
Network: &net.IPNet{ Network: &net.IPNet{
IP: net.ParseIP("10.20.0.0"), IP: net.ParseIP("10.20.0.0"),
@@ -1433,13 +1435,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err) require.NoError(t, err)
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics) accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay) secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil) mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/bind"
"github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/iface/wgproxy"
) )
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
CreateOnAndroid(routeRange []string, ip string, domains []string) error CreateOnAndroid(routeRange []string, ip string, domains []string) error
IsUserspaceBind() bool IsUserspaceBind() bool
Name() string Name() string
Address() device.WGAddress Address() wgaddr.Address
ToInterface() *net.Interface ToInterface() *net.Interface
Up() (*bind.UniversalUDPMuxDefault, error) Up() (*bind.UniversalUDPMuxDefault, error)
UpdateAddr(newAddr string) error UpdateAddr(newAddr string) error

View File

@@ -0,0 +1,107 @@
package ingressgw
import (
"fmt"
"sync"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
type DNATFirewall interface {
AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error)
DeleteDNATRule(rule firewall.Rule) error
}
type RulePair struct {
firewall.ForwardRule
firewall.Rule
}
type Manager struct {
dnatFirewall DNATFirewall
rules map[string]RulePair // keys is the ID of the ForwardRule
rulesMu sync.Mutex
}
func NewManager(dnatFirewall DNATFirewall) *Manager {
return &Manager{
dnatFirewall: dnatFirewall,
rules: make(map[string]RulePair),
}
}
func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
var mErr *multierror.Error
toDelete := make(map[string]RulePair, len(h.rules))
for id, r := range h.rules {
toDelete[id] = r
}
// Process new/updated rules
for _, fwdRule := range forwardRules {
id := fwdRule.ID()
if _, ok := h.rules[id]; ok {
delete(toDelete, id)
continue
}
rule, err := h.dnatFirewall.AddDNATRule(fwdRule)
if err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
continue
}
log.Infof("forward rule has been added '%s'", fwdRule)
h.rules[id] = RulePair{
ForwardRule: fwdRule,
Rule: rule,
}
}
// Remove deleted rules
for id, rulePair := range toDelete {
if err := h.dnatFirewall.DeleteDNATRule(rulePair.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
}
log.Infof("forward rule has been deleted '%s'", rulePair.ForwardRule)
delete(h.rules, id)
}
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Close() error {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
log.Infof("clean up all (%d) forward rules", len(h.rules))
var mErr *multierror.Error
for _, rule := range h.rules {
if err := h.dnatFirewall.DeleteDNATRule(rule.Rule); err != nil {
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
}
}
h.rules = make(map[string]RulePair)
return nberrors.FormatErrorOrNil(mErr)
}
func (h *Manager) Rules() []firewall.ForwardRule {
h.rulesMu.Lock()
defer h.rulesMu.Unlock()
rules := make([]firewall.ForwardRule, 0, len(h.rules))
for _, rulePair := range h.rules {
rules = append(rules, rulePair.ForwardRule)
}
return rules
}

View File

@@ -0,0 +1,281 @@
package ingressgw
import (
"fmt"
"net/netip"
"testing"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
)
var (
_ firewall.Rule = (*MocFwRule)(nil)
_ DNATFirewall = &MockDNATFirewall{}
)
type MocFwRule struct {
id string
}
func (m *MocFwRule) ID() string {
return string(m.id)
}
type MockDNATFirewall struct {
throwError bool
}
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
if m.throwError {
return nil, fmt.Errorf("moc error")
}
fwRule := &MocFwRule{
id: fwdRule.ID(),
}
return fwRule, nil
}
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
if m.throwError {
return fmt.Errorf("moc error")
}
return nil
}
func (m *MockDNATFirewall) forceToThrowErrors() {
m.throwError = true
}
func TestManager_AddRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
updates := []firewall.ForwardRule{
{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
},
{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}}
if err := mgr.Update(updates); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != len(updates) {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UpdateRule(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
if rules[0].TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].DestinationPort.String() != ruleUDP.DestinationPort.String() {
t.Errorf("unexpected rule: %v", rules[0])
}
if rules[0].Protocol != ruleUDP.Protocol {
t.Errorf("unexpected rule: %v", rules[0])
}
}
func TestManager_ExtendRules(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 2 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_UnderlingError(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
ruleUDP := firewall.ForwardRule{
Protocol: firewall.ProtocolUDP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
fw.forceToThrowErrors()
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
t.Errorf("expected error")
}
rules := mgr.Rules()
if len(rules) != 1 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_Cleanup(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}
func TestManager_DeleteBrokenRule(t *testing.T) {
fw := &MockDNATFirewall{}
// force to throw errors when Add DNAT Rule
fw.forceToThrowErrors()
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
// simulate that to remove a broken rule
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
}
func TestManager_Close(t *testing.T) {
fw := &MockDNATFirewall{}
mgr := NewManager(fw)
port, _ := firewall.NewPort(8080)
ruleTCP := firewall.ForwardRule{
Protocol: firewall.ProtocolTCP,
DestinationPort: *port,
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
TranslatedPort: *port,
}
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
t.Errorf("unexpected error: %v", err)
}
if err := mgr.Close(); err != nil {
t.Errorf("unexpected error: %v", err)
}
rules := mgr.Rules()
if len(rules) != 0 {
t.Errorf("unexpected rules count: %d", len(rules))
}
}

View File

@@ -0,0 +1,58 @@
package internal
import (
"errors"
"fmt"
"net"
"net/netip"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
switch protocol {
case mgmProto.RuleProtocol_TCP:
return firewallManager.ProtocolTCP, nil
case mgmProto.RuleProtocol_UDP:
return firewallManager.ProtocolUDP, nil
case mgmProto.RuleProtocol_ICMP:
return firewallManager.ProtocolICMP, nil
case mgmProto.RuleProtocol_ALL:
return firewallManager.ProtocolALL, nil
default:
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
}
}
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
if portInfo == nil {
return nil, errors.New("portInfo cannot be nil")
}
if portInfo.GetPort() != 0 {
return firewallManager.NewPort(int(portInfo.GetPort()))
}
if portInfo.GetRange() != nil {
return firewallManager.NewPort(int(portInfo.GetRange().Start), int(portInfo.GetRange().End))
}
return nil, fmt.Errorf("invalid portInfo: %v", portInfo)
}
func convertToIP(rawIP []byte) (netip.Addr, error) {
if rawIP == nil {
return netip.Addr{}, errors.New("input bytes cannot be nil")
}
if len(rawIP) != net.IPv4len && len(rawIP) != net.IPv6len {
return netip.Addr{}, fmt.Errorf("invalid IP length: %d", len(rawIP))
}
if len(rawIP) == net.IPv4len {
return netip.AddrFrom4([4]byte(rawIP)), nil
}
return netip.AddrFrom16([16]byte(rawIP)), nil
}

View File

@@ -0,0 +1,306 @@
//go:build linux && !android
package conntrack
import (
"encoding/binary"
"fmt"
"net/netip"
"sync"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
nfct "github.com/ti-mo/conntrack"
"github.com/ti-mo/netfilter"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
const defaultChannelSize = 100
// ConnTrack manages kernel-based conntrack events
type ConnTrack struct {
flowLogger nftypes.FlowLogger
iface nftypes.IFaceMapper
conn *nfct.Conn
mux sync.Mutex
instanceID uuid.UUID
started bool
done chan struct{}
sysctlModified bool
}
// New creates a new connection tracker that interfaces with the kernel's conntrack system
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
return &ConnTrack{
flowLogger: flowLogger,
iface: iface,
instanceID: uuid.New(),
started: false,
done: make(chan struct{}, 1),
}
}
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
func (c *ConnTrack) Start(enableCounters bool) error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
return nil
}
log.Info("Starting conntrack event listening")
if enableCounters {
c.EnableAccounting()
}
conn, err := nfct.Dial(nil)
if err != nil {
return fmt.Errorf("dial conntrack: %w", err)
}
c.conn = conn
events := make(chan nfct.Event, defaultChannelSize)
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
netfilter.GroupCTNew,
netfilter.GroupCTDestroy,
})
if err != nil {
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
return fmt.Errorf("start conntrack listener: %w", err)
}
c.started = true
go c.receiverRoutine(events, errChan)
return nil
}
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
for {
select {
case event := <-events:
c.handleEvent(event)
case err := <-errChan:
log.Errorf("Error from conntrack event listener: %v", err)
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
return
case <-c.done:
return
}
}
}
// Stop stops the connection tracking. This method is idempotent.
func (c *ConnTrack) Stop() {
c.mux.Lock()
defer c.mux.Unlock()
if !c.started {
return
}
log.Info("Stopping conntrack event listening")
select {
case c.done <- struct{}{}:
default:
}
if c.conn != nil {
if err := c.conn.Close(); err != nil {
log.Errorf("Error closing conntrack connection: %v", err)
}
c.conn = nil
}
c.started = false
c.RestoreAccounting()
}
// Close stops listening for events and cleans up resources
func (c *ConnTrack) Close() error {
c.mux.Lock()
defer c.mux.Unlock()
if c.started {
select {
case c.done <- struct{}{}:
default:
}
}
if c.conn != nil {
err := c.conn.Close()
c.conn = nil
c.started = false
c.RestoreAccounting()
if err != nil {
return fmt.Errorf("close conntrack: %w", err)
}
}
return nil
}
// handleEvent processes incoming conntrack events
func (c *ConnTrack) handleEvent(event nfct.Event) {
if event.Flow == nil {
return
}
if event.Type != nfct.EventNew && event.Type != nfct.EventDestroy {
return
}
flow := *event.Flow
proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol)
if proto == nftypes.ProtocolUnknown {
return
}
srcIP := flow.TupleOrig.IP.SourceAddress
dstIP := flow.TupleOrig.IP.DestinationAddress
if !c.relevantFlow(srcIP, dstIP) {
return
}
var srcPort, dstPort uint16
var icmpType, icmpCode uint8
switch proto {
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
srcPort = flow.TupleOrig.Proto.SourcePort
dstPort = flow.TupleOrig.Proto.DestinationPort
case nftypes.ICMP:
icmpType = flow.TupleOrig.Proto.ICMPType
icmpCode = flow.TupleOrig.Proto.ICMPCode
}
flowID := c.getFlowID(flow.ID)
direction := c.inferDirection(srcIP, dstIP)
eventType := nftypes.TypeStart
eventStr := "New"
if event.Type == nfct.EventDestroy {
eventType = nftypes.TypeEnd
eventStr = "Ended"
}
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
c.flowLogger.StoreEvent(nftypes.EventFields{
FlowID: flowID,
Type: eventType,
Direction: direction,
Protocol: proto,
SourceIP: srcIP,
DestIP: dstIP,
SourcePort: srcPort,
DestPort: dstPort,
ICMPType: icmpType,
ICMPCode: icmpCode,
RxPackets: c.mapRxPackets(flow, direction),
TxPackets: c.mapTxPackets(flow, direction),
RxBytes: c.mapRxBytes(flow, direction),
TxBytes: c.mapTxBytes(flow, direction),
})
}
// relevantFlow checks if the flow is related to the specified interface
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
// TODO: filter traffic by interface
wgnet := c.iface.Address().Network
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
return false
}
return true
}
// mapRxPackets maps packet counts to RX based on flow direction
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersOrig is from external to us (RX)
// For Egress: CountersReply is from external to us (RX)
if direction == nftypes.Ingress {
return flow.CountersOrig.Packets
}
return flow.CountersReply.Packets
}
// mapTxPackets maps packet counts to TX based on flow direction
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersReply is from us to external (TX)
// For Egress: CountersOrig is from us to external (TX)
if direction == nftypes.Ingress {
return flow.CountersReply.Packets
}
return flow.CountersOrig.Packets
}
// mapRxBytes maps byte counts to RX based on flow direction
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersOrig is from external to us (RX)
// For Egress: CountersReply is from external to us (RX)
if direction == nftypes.Ingress {
return flow.CountersOrig.Bytes
}
return flow.CountersReply.Bytes
}
// mapTxBytes maps byte counts to TX based on flow direction
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
// For Ingress: CountersReply is from us to external (TX)
// For Egress: CountersOrig is from us to external (TX)
if direction == nftypes.Ingress {
return flow.CountersReply.Bytes
}
return flow.CountersOrig.Bytes
}
// getFlowID creates a unique UUID based on the conntrack ID and instance ID
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
var buf [4]byte
binary.BigEndian.PutUint32(buf[:], conntrackID)
return uuid.NewSHA1(c.instanceID, buf[:])
}
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
wgaddr := c.iface.Address().IP
wgnetwork := c.iface.Address().Network
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
switch {
case wgaddr.Equal(src):
return nftypes.Egress
case wgaddr.Equal(dst):
return nftypes.Ingress
case wgnetwork.Contains(src):
// netbird network -> resource network
return nftypes.Ingress
case wgnetwork.Contains(dst):
// resource network -> netbird network
return nftypes.Egress
// TODO: handle site2site traffic
}
return nftypes.DirectionUnknown
}

View File

@@ -0,0 +1,9 @@
//go:build !linux || android
package conntrack
import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker {
return nil
}

View File

@@ -0,0 +1,73 @@
//go:build linux && !android
package conntrack
import (
"fmt"
"os"
"strconv"
"strings"
log "github.com/sirupsen/logrus"
)
const (
// conntrackAcctPath is the sysctl path for conntrack accounting
conntrackAcctPath = "net.netfilter.nf_conntrack_acct"
)
// EnableAccounting ensures that connection tracking accounting is enabled in the kernel.
func (c *ConnTrack) EnableAccounting() {
// haven't restored yet
if c.sysctlModified {
return
}
modified, err := setSysctl(conntrackAcctPath, 1)
if err != nil {
log.Warnf("Failed to enable conntrack accounting: %v", err)
return
}
c.sysctlModified = modified
}
// RestoreAccounting restores the connection tracking accounting setting to its original value.
func (c *ConnTrack) RestoreAccounting() {
if !c.sysctlModified {
return
}
if _, err := setSysctl(conntrackAcctPath, 0); err != nil {
log.Warnf("Failed to restore conntrack accounting: %v", err)
return
}
c.sysctlModified = false
}
// setSysctl sets a sysctl configuration and returns whether it was modified.
func setSysctl(key string, desiredValue int) (bool, error) {
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
currentValue, err := os.ReadFile(path)
if err != nil {
return false, fmt.Errorf("read sysctl %s: %w", key, err)
}
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
if err != nil && len(currentValue) > 0 {
return false, fmt.Errorf("convert current value to int: %w", err)
}
if currentV == desiredValue {
return false, nil
}
// nolint:gosec
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
return false, fmt.Errorf("write sysctl %s: %w", key, err)
}
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
return true, nil
}

View File

@@ -0,0 +1,125 @@
package logger
import (
"context"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/netflow/store"
"github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
)
type rcvChan chan *types.EventFields
type Logger struct {
mux sync.Mutex
ctx context.Context
cancel context.CancelFunc
enabled atomic.Bool
rcvChan atomic.Pointer[rcvChan]
cancelReceiver context.CancelFunc
statusRecorder *peer.Status
Store types.Store
}
func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
ctx, cancel := context.WithCancel(ctx)
return &Logger{
ctx: ctx,
cancel: cancel,
statusRecorder: statusRecorder,
Store: store.NewMemoryStore(),
}
}
func (l *Logger) StoreEvent(flowEvent types.EventFields) {
if !l.enabled.Load() {
return
}
c := l.rcvChan.Load()
if c == nil {
return
}
select {
case *c <- &flowEvent:
default:
// todo: we should collect or log on this
}
}
func (l *Logger) Enable() {
go l.startReceiver()
}
func (l *Logger) startReceiver() {
if l.enabled.Load() {
return
}
l.mux.Lock()
ctx, cancel := context.WithCancel(l.ctx)
l.cancelReceiver = cancel
l.mux.Unlock()
c := make(rcvChan, 100)
l.rcvChan.Store(&c)
l.enabled.Store(true)
for {
select {
case <-ctx.Done():
log.Info("flow Memory store receiver stopped")
return
case eventFields := <-c:
id := uuid.New()
event := types.Event{
ID: id,
EventFields: *eventFields,
Timestamp: time.Now(),
}
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
event.SourceResourceID = []byte(srcResId)
event.DestResourceID = []byte(dstResId)
l.Store.StoreEvent(&event)
}
}
}
func (l *Logger) Disable() {
l.stop()
l.Store.Close()
}
func (l *Logger) stop() {
if !l.enabled.Load() {
return
}
l.enabled.Store(false)
l.mux.Lock()
if l.cancelReceiver != nil {
l.cancelReceiver()
l.cancelReceiver = nil
}
l.rcvChan.Store(nil)
l.mux.Unlock()
}
func (l *Logger) GetEvents() []*types.Event {
return l.Store.GetEvents()
}
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
l.Store.DeleteEvents(ids)
}
func (l *Logger) Close() {
l.stop()
l.cancel()
}

View File

@@ -0,0 +1,67 @@
package logger_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
"github.com/netbirdio/netbird/client/internal/netflow/types"
)
func TestStore(t *testing.T) {
logger := logger.New(context.Background(), nil)
logger.Enable()
event := types.EventFields{
FlowID: uuid.New(),
Type: types.TypeStart,
Direction: types.Ingress,
Protocol: 6,
}
wait := func() { time.Sleep(time.Millisecond) }
wait()
logger.StoreEvent(event)
wait()
allEvents := logger.GetEvents()
matched := false
for _, e := range allEvents {
if e.EventFields.FlowID == event.FlowID {
matched = true
}
}
if !matched {
t.Errorf("didn't match any event")
}
// test disable
logger.Disable()
wait()
logger.StoreEvent(event)
wait()
allEvents = logger.GetEvents()
if len(allEvents) != 0 {
t.Errorf("expected 0 events, got %d", len(allEvents))
}
// test re-enable
logger.Enable()
wait()
logger.StoreEvent(event)
wait()
allEvents = logger.GetEvents()
matched = false
for _, e := range allEvents {
if e.EventFields.FlowID == event.FlowID {
matched = true
}
}
if !matched {
t.Errorf("didn't match any event")
}
}

View File

@@ -0,0 +1,235 @@
package netflow
import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
"github.com/netbirdio/netbird/client/internal/netflow/logger"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/flow/client"
"github.com/netbirdio/netbird/flow/proto"
)
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
ctx context.Context
receiverClient *client.GRPCClient
publicKey []byte
}
// NewManager creates a new netflow manager
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
flowLogger := logger.New(ctx, statusRecorder)
var ct nftypes.ConnTracker
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
ct = conntrack.New(flowLogger, iface)
}
return &Manager{
logger: flowLogger,
conntrack: ct,
ctx: ctx,
publicKey: publicKey,
}
}
// Update applies new flow configuration settings
// needsNewClient checks if a new client needs to be created
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
current := m.flowConfig
return previous == nil ||
!previous.Enabled ||
previous.TokenPayload != current.TokenPayload ||
previous.TokenSignature != current.TokenSignature ||
previous.URL != current.URL
}
// enableFlow starts components for flow tracking
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
// first make sender ready so events don't pile up
if m.needsNewClient(previous) {
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("error closing previous flow client: %s", err)
}
}
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
if err != nil {
return fmt.Errorf("create client: %w", err)
}
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
m.receiverClient = flowClient
go m.receiveACKs(flowClient)
go m.startSender()
}
m.logger.Enable()
if m.conntrack != nil {
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
return fmt.Errorf("start conntrack: %w", err)
}
}
return nil
}
// disableFlow stops components for flow tracking
func (m *Manager) disableFlow() error {
if m.conntrack != nil {
m.conntrack.Stop()
}
m.logger.Disable()
if m.receiverClient != nil {
return m.receiverClient.Close()
}
return nil
}
// Update applies new flow configuration settings
func (m *Manager) Update(update *nftypes.FlowConfig) error {
if update == nil {
return nil
}
m.mux.Lock()
defer m.mux.Unlock()
previous := m.flowConfig
m.flowConfig = update
if update.Enabled {
return m.enableFlow(previous)
}
return m.disableFlow()
}
// Close cleans up all resources
func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if m.conntrack != nil {
m.conntrack.Close()
}
if m.receiverClient != nil {
if err := m.receiverClient.Close(); err != nil {
log.Warnf("failed to close receiver client: %s", err)
}
}
m.logger.Close()
}
// GetLogger returns the flow logger
func (m *Manager) GetLogger() nftypes.FlowLogger {
return m.logger
}
func (m *Manager) startSender() {
ticker := time.NewTicker(m.flowConfig.Interval)
defer ticker.Stop()
for {
select {
case <-m.ctx.Done():
return
case <-ticker.C:
events := m.logger.GetEvents()
for _, event := range events {
if err := m.send(event); err != nil {
log.Errorf("failed to send flow event to server: %s", err)
continue
}
log.Tracef("sent flow event: %s", event.ID)
}
}
}
}
func (m *Manager) receiveACKs(client *client.GRPCClient) {
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
log.Tracef("received flow event ack: %s", ack.EventId)
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
return nil
})
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("failed to receive flow event ack: %s", err)
}
}
func (m *Manager) send(event *nftypes.Event) error {
m.mux.Lock()
client := m.receiverClient
m.mux.Unlock()
if client == nil {
return nil
}
return client.Send(toProtoEvent(m.publicKey, event))
}
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
protoEvent := &proto.FlowEvent{
EventId: event.ID[:],
Timestamp: timestamppb.New(event.Timestamp),
PublicKey: publicKey,
FlowFields: &proto.FlowFields{
FlowId: event.FlowID[:],
RuleId: event.RuleID,
Type: proto.Type(event.Type),
Direction: proto.Direction(event.Direction),
Protocol: uint32(event.Protocol),
SourceIp: event.SourceIP.AsSlice(),
DestIp: event.DestIP.AsSlice(),
RxPackets: event.RxPackets,
TxPackets: event.TxPackets,
RxBytes: event.RxBytes,
TxBytes: event.TxBytes,
SourceResourceId: event.SourceResourceID,
DestResourceId: event.DestResourceID,
},
}
if event.Protocol == nftypes.ICMP {
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
IcmpInfo: &proto.ICMPInfo{
IcmpType: uint32(event.ICMPType),
IcmpCode: uint32(event.ICMPCode),
},
}
return protoEvent
}
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_PortInfo{
PortInfo: &proto.PortInfo{
SourcePort: uint32(event.SourcePort),
DestPort: uint32(event.DestPort),
},
}
return protoEvent
}

Some files were not shown because too many files have changed in this diff Show More