mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
110 Commits
v0.37.2
...
test/netwo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d2c774378 | ||
|
|
ab2e3fec72 | ||
|
|
47f88f7057 | ||
|
|
ee33a6ed7c | ||
|
|
da662cfd08 | ||
|
|
ed2ee1ee9d | ||
|
|
76d73548d6 | ||
|
|
11828a064a | ||
|
|
0c2a3dd937 | ||
|
|
cd9eff5331 | ||
|
|
47dcf8d68c | ||
|
|
80ceb80197 | ||
|
|
cc8f6bcaf3 | ||
|
|
636a0e2475 | ||
|
|
e66e329bf6 | ||
|
|
aaa23beeec | ||
|
|
6bef474e9e | ||
|
|
81040ff80a | ||
|
|
c73481aee4 | ||
|
|
92286b2541 | ||
|
|
d8bcf745b0 | ||
|
|
8430139d80 | ||
|
|
a2962b4ce0 | ||
|
|
16fffdb75b | ||
|
|
036cecbf46 | ||
|
|
3482852bb6 | ||
|
|
fd62665b1f | ||
|
|
fc1da94520 | ||
|
|
1ffe48f0d4 | ||
|
|
a3b8a21385 | ||
|
|
86492b88c4 | ||
|
|
d08a629f9e | ||
|
|
36da464413 | ||
|
|
268e3404d3 | ||
|
|
54d0591833 | ||
|
|
86370a0e7b | ||
|
|
cb16d0f45f | ||
|
|
e8d8bd8f18 | ||
|
|
8b07f21c28 | ||
|
|
54be772ffd | ||
|
|
3c3a454e61 | ||
|
|
5ff77b3595 | ||
|
|
b180edbe5c | ||
|
|
de3b5c78d7 | ||
|
|
0b42f40cf6 | ||
|
|
0a042ac36d | ||
|
|
e7f921d787 | ||
|
|
e9f11fb11b | ||
|
|
419ed275fa | ||
|
|
2d4fcaf186 | ||
|
|
acf172b52c | ||
|
|
8c81a823fa | ||
|
|
619c549547 | ||
|
|
9a713a0987 | ||
|
|
c4945cd565 | ||
|
|
1e10c17ecb | ||
|
|
96d5190436 | ||
|
|
d19c26df06 | ||
|
|
36e36414d9 | ||
|
|
7e69589e05 | ||
|
|
aa613ab79a | ||
|
|
6ead0ff95e | ||
|
|
0db65a8984 | ||
|
|
c138807e95 | ||
|
|
637c0c8949 | ||
|
|
c72e13d8e6 | ||
|
|
f6d7bccfa0 | ||
|
|
e3ed01cafb | ||
|
|
fa748a7ec2 | ||
|
|
cccc615783 | ||
|
|
2021463ca0 | ||
|
|
f48cfd52e9 | ||
|
|
6838f53f40 | ||
|
|
8276236dfa | ||
|
|
994b923d56 | ||
|
|
59e2432231 | ||
|
|
eee0d123e4 | ||
|
|
e943203ae2 | ||
|
|
6a775217cf | ||
|
|
175674749f | ||
|
|
1e534cecf6 | ||
|
|
aa3aa8c6a8 | ||
|
|
fbdfe45c25 | ||
|
|
81ee172db8 | ||
|
|
f8fd65a65f | ||
|
|
62b978c050 | ||
|
|
4ebf1410c6 | ||
|
|
630edf2480 | ||
|
|
ea469d28d7 | ||
|
|
597f1d47b8 | ||
|
|
fcc96417f9 | ||
|
|
8755211a60 | ||
|
|
e6d4653b08 | ||
|
|
eb69f2de78 | ||
|
|
206420c085 | ||
|
|
88a864c195 | ||
|
|
a789e9e6d8 | ||
|
|
9930913e4e | ||
|
|
48675f579f | ||
|
|
afec455f86 | ||
|
|
035c5d9f23 | ||
|
|
b2a5b29fb2 | ||
|
|
9ec61206c2 | ||
|
|
1b011a2d85 | ||
|
|
a85ea1ddb0 | ||
|
|
829e40d2aa | ||
|
|
6344e34880 | ||
|
|
a76ca8c565 | ||
|
|
26693e4ea8 | ||
|
|
f6a71f4193 |
10
.github/workflows/golang-test-linux.yml
vendored
10
.github/workflows/golang-test-linux.yml
vendored
@@ -258,7 +258,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
@@ -325,8 +325,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres', 'mysql' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
@@ -392,7 +392,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
@@ -461,7 +461,7 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ 'amd64' ]
|
||||||
store: [ 'sqlite', 'postgres']
|
store: [ 'sqlite', 'postgres']
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
|||||||
98
client/cmd/forwarding_rules.go
Normal file
98
client/cmd/forwarding_rules.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var forwardingRulesCmd = &cobra.Command{
|
||||||
|
Use: "forwarding",
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Long: `Commands to list forwarding rules.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var forwardingRulesListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Example: " netbird forwarding list",
|
||||||
|
Long: "Commands to list forwarding rules.",
|
||||||
|
RunE: listForwardingRules,
|
||||||
|
}
|
||||||
|
|
||||||
|
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.GetRules()) == 0 {
|
||||||
|
cmd.Println("No forwarding rules available.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
printForwardingRules(cmd, resp.GetRules())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||||
|
cmd.Println("Available forwarding rules:")
|
||||||
|
|
||||||
|
// Sort rules by translated address
|
||||||
|
sort.Slice(rules, func(i, j int) bool {
|
||||||
|
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||||
|
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||||
|
}
|
||||||
|
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||||
|
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||||
|
}
|
||||||
|
|
||||||
|
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||||
|
})
|
||||||
|
|
||||||
|
var lastIP string
|
||||||
|
for _, rule := range rules {
|
||||||
|
dPort := portToString(rule.GetDestinationPort())
|
||||||
|
tPort := portToString(rule.GetTranslatedPort())
|
||||||
|
if lastIP != rule.GetTranslatedAddress() {
|
||||||
|
lastIP = rule.GetTranslatedAddress()
|
||||||
|
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFirstPort(portInfo *proto.PortInfo) int {
|
||||||
|
switch v := portInfo.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return int(v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return int(v.Range.GetStart())
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func portToString(translatedPort *proto.PortInfo) string {
|
||||||
|
switch v := translatedPort.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return fmt.Sprintf("%d", v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||||
|
default:
|
||||||
|
return "No port specified"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -145,6 +145,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
rootCmd.AddCommand(networksCMD)
|
rootCmd.AddCommand(networksCMD)
|
||||||
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||||
@@ -153,6 +154,8 @@ func init() {
|
|||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
|
|
||||||
|
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
|
||||||
|
|
||||||
debugCmd.AddCommand(debugBundleCmd)
|
debugCmd.AddCommand(debugBundleCmd)
|
||||||
debugCmd.AddCommand(logCmd)
|
debugCmd.AddCommand(logCmd)
|
||||||
logCmd.AddCommand(logLevelCmd)
|
logCmd.AddCommand(logLevelCmd)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -89,13 +90,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -134,10 +134,11 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
run := make(chan error, 1)
|
run := make(chan struct{}, 1)
|
||||||
|
clientErr := make(chan error, 1)
|
||||||
go func() {
|
go func() {
|
||||||
if err := client.Run(run); err != nil {
|
if err := client.Run(run); err != nil {
|
||||||
run <- err
|
clientErr <- err
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -147,13 +148,9 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
}
|
}
|
||||||
return startCtx.Err()
|
return startCtx.Err()
|
||||||
case err := <-run:
|
case err := <-clientErr:
|
||||||
if err != nil {
|
|
||||||
if stopErr := client.Stop(); stopErr != nil {
|
|
||||||
return fmt.Errorf("stop error after failed to startup. Stop error: %w. Start error: %w", stopErr, err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("startup: %w", err)
|
return fmt.Errorf("startup: %w", err)
|
||||||
}
|
case <-run:
|
||||||
}
|
}
|
||||||
|
|
||||||
c.connect = client
|
c.connect = client
|
||||||
|
|||||||
@@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() device.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -32,8 +32,6 @@ type entry struct {
|
|||||||
type aclManager struct {
|
type aclManager struct {
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
|
||||||
|
|
||||||
entries aclEntries
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
@@ -41,12 +39,10 @@ type aclManager struct {
|
|||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
m := &aclManager{
|
m := &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
routingFwChainName: routingFwChainName,
|
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
@@ -79,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
@@ -314,9 +311,12 @@ func (m *aclManager) seedInitialEntries() {
|
|||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
|
// Inbound is handled by our ACLs, the rest is dropped.
|
||||||
|
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
|
||||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ type Manager struct {
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
_ string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
@@ -226,6 +227,22 @@ func (m *Manager) DisableRouting() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,15 +10,15 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -117,8 +117,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -184,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -26,19 +27,33 @@ const (
|
|||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
tableMangle = "mangle"
|
tableMangle = "mangle"
|
||||||
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
chainPREROUTING = "PREROUTING"
|
chainPREROUTING = "PREROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||||
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
jumpPre = "jump-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNat = "jump-nat"
|
jumpNatPre = "jump-nat-pre"
|
||||||
|
jumpNatPost = "jump-nat-post"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
|
fwdSuffix = "_fwd"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ruleInfo struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
rule []string
|
||||||
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Sources []netip.Prefix
|
||||||
Destination netip.Prefix
|
Destination netip.Prefix
|
||||||
@@ -62,6 +77,7 @@ type router struct {
|
|||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
|||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -104,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -111,7 +129,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -139,9 +157,9 @@ func (r *router) AddRouteFiltering(
|
|||||||
var err error
|
var err error
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
// after the established rule
|
// after the established rule
|
||||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...)
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
} else {
|
} else {
|
||||||
err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...)
|
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -156,12 +174,12 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
setName := r.findSetNameInRule(rule)
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -212,6 +230,10 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -238,6 +260,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -264,7 +290,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -277,7 +303,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -305,7 +331,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
delete(r.rules, k)
|
delete(r.rules, k)
|
||||||
@@ -343,9 +369,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -365,16 +393,22 @@ func (r *router) createContainers() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -415,27 +449,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
|
||||||
table := r.getTableForChain(chain)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
|
||||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) getTableForChain(chain string) string {
|
|
||||||
switch chain {
|
|
||||||
case chainRTNAT:
|
|
||||||
return tableNat
|
|
||||||
case chainRTPRE:
|
|
||||||
return tableMangle
|
|
||||||
default:
|
|
||||||
return tableFilter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
establishedRule := getConntrackEstablished()
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
@@ -454,28 +467,43 @@ func (r *router) addJumpRules() error {
|
|||||||
// Jump to NAT chain
|
// Jump to NAT chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpNat] = natRule
|
r.rules[jumpNatPost] = natRule
|
||||||
|
|
||||||
// Jump to prerouting chain
|
// Jump to mangle prerouting chain
|
||||||
preRule := []string{"-j", chainRTPRE}
|
preRule := []string{"-j", chainRTPRE}
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpPre] = preRule
|
r.rules[jumpManglePre] = preRule
|
||||||
|
|
||||||
|
// Jump to nat prerouting chain
|
||||||
|
rdrRule := []string{"-j", chainRTRDR}
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||||
|
return fmt.Errorf("add nat prerouting jump rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNatPre] = rdrRule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
table := tableNat
|
var table, chain string
|
||||||
chain := chainPOSTROUTING
|
switch ruleKey {
|
||||||
if ruleKey == jumpPre {
|
case jumpNatPost:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPOSTROUTING
|
||||||
|
case jumpManglePre:
|
||||||
table = tableMangle
|
table = tableMangle
|
||||||
chain = chainPREROUTING
|
chain = chainPREROUTING
|
||||||
|
case jumpNatPre:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPREROUTING
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||||
@@ -520,6 +548,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -535,6 +565,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -564,6 +595,137 @@ func (r *router) updateState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
toDestination := rule.TranslatedAddress.String()
|
||||||
|
switch {
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
// no translated port, use original port
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
// need the "/originalport" suffix to avoid dnat port randomization
|
||||||
|
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := strings.ToLower(string(rule.Protocol))
|
||||||
|
|
||||||
|
rules := make(map[string]ruleInfo, 3)
|
||||||
|
|
||||||
|
// DNAT rule
|
||||||
|
dnatRule := []string{
|
||||||
|
"!", "-i", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", toDestination,
|
||||||
|
}
|
||||||
|
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||||
|
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTRDR,
|
||||||
|
rule: dnatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNAT rule
|
||||||
|
snatRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTNAT,
|
||||||
|
rule: snatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward filtering rule, if fwd policy is DROP
|
||||||
|
forwardRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
}
|
||||||
|
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||||
|
table: tableFilter,
|
||||||
|
chain: chainRTFWDOUT,
|
||||||
|
rule: forwardRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||||
|
log.Errorf("rollback failed: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||||
|
}
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||||
|
// On rollback error, add to rules map for next cleanup
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
r.updateState()
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
|
|||||||
@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Now 5 rules:
|
// Now 5 rules:
|
||||||
// 1. established rule in forward chain
|
// 1. established rule forward in
|
||||||
// 2. jump rule to NAT chain
|
// 2. estbalished rule forward out
|
||||||
// 3. jump rule to PRE chain
|
// 3. jump rule to POST nat chain
|
||||||
// 4. static outbound masquerade rule
|
// 4. jump rule to PRE mangle chain
|
||||||
// 5. static return masquerade rule
|
// 5. jump rule to PRE nat chain
|
||||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
// 6. static outbound masquerade rule
|
||||||
|
// 7. static return masquerade rule
|
||||||
|
require.Len(t, manager.rules, 7, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -328,18 +330,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
// Log the internal rule
|
// Log the internal rule
|
||||||
t.Logf("Internal rule: %v", rule)
|
t.Logf("Internal rule: %v", rule)
|
||||||
|
|
||||||
// Check if the rule exists in iptables
|
// Check if the rule exists in iptables
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,6 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,13 +4,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -18,7 +17,7 @@ func (i *InterfaceState) Name() string {
|
|||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ const (
|
|||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
// of the properties to hold data of the created rule
|
// of the properties to hold data of the created rule
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
GetRuleID() string
|
ID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// RuleDirection is the traffic direction which a rule is applied
|
||||||
@@ -65,13 +65,13 @@ type Manager interface {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddPeerFiltering(
|
AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -80,7 +80,15 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto Protocol,
|
||||||
|
sPort *Port,
|
||||||
|
dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
DeleteRouteRule(rule Rule) error
|
DeleteRouteRule(rule Rule) error
|
||||||
@@ -105,6 +113,12 @@ type Manager interface {
|
|||||||
EnableRouting() error
|
EnableRouting() error
|
||||||
|
|
||||||
DisableRouting() error
|
DisableRouting() error
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
AddDNATRule(ForwardRule) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
DeleteDNATRule(Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
27
client/firewall/manager/forward_rule.go
Normal file
27
client/firewall/manager/forward_rule.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||||
|
type ForwardRule struct {
|
||||||
|
Protocol Protocol
|
||||||
|
DestinationPort Port
|
||||||
|
TranslatedAddress netip.Addr
|
||||||
|
TranslatedPort Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) ID() string {
|
||||||
|
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||||
|
r.Protocol,
|
||||||
|
r.DestinationPort.String(),
|
||||||
|
r.TranslatedAddress.String(),
|
||||||
|
r.TranslatedPort.String())
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) String() string {
|
||||||
|
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
|
||||||
|
}
|
||||||
@@ -1,30 +1,12 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
|
||||||
type Protocol string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ProtocolTCP is the TCP protocol
|
|
||||||
ProtocolTCP Protocol = "tcp"
|
|
||||||
|
|
||||||
// ProtocolUDP is the UDP protocol
|
|
||||||
ProtocolUDP Protocol = "udp"
|
|
||||||
|
|
||||||
// ProtocolICMP is the ICMP protocol
|
|
||||||
ProtocolICMP Protocol = "icmp"
|
|
||||||
|
|
||||||
// ProtocolALL cover all supported protocols
|
|
||||||
ProtocolALL Protocol = "all"
|
|
||||||
|
|
||||||
// ProtocolUnknown unknown protocol
|
|
||||||
ProtocolUnknown Protocol = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Port of the address for firewall rule
|
// Port of the address for firewall rule
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
type Port struct {
|
type Port struct {
|
||||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||||
IsRange bool
|
IsRange bool
|
||||||
@@ -33,6 +15,25 @@ type Port struct {
|
|||||||
Values []uint16
|
Values []uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewPort(ports ...int) (*Port, error) {
|
||||||
|
if len(ports) == 0 {
|
||||||
|
return nil, fmt.Errorf("no port provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
ports16 := make([]uint16, len(ports))
|
||||||
|
for i, port := range ports {
|
||||||
|
if port < 1 || port > 65535 {
|
||||||
|
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
|
||||||
|
}
|
||||||
|
ports16[i] = uint16(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Port{
|
||||||
|
IsRange: len(ports) > 1,
|
||||||
|
Values: ports16,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// String interface implementation
|
// String interface implementation
|
||||||
func (p *Port) String() string {
|
func (p *Port) String() string {
|
||||||
var ports string
|
var ports string
|
||||||
|
|||||||
19
client/firewall/manager/protocol.go
Normal file
19
client/firewall/manager/protocol.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
// Protocol is the protocol of the port
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
|
type Protocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProtocolTCP is the TCP protocol
|
||||||
|
ProtocolTCP Protocol = "tcp"
|
||||||
|
|
||||||
|
// ProtocolUDP is the UDP protocol
|
||||||
|
ProtocolUDP Protocol = "udp"
|
||||||
|
|
||||||
|
// ProtocolICMP is the ICMP protocol
|
||||||
|
ProtocolICMP Protocol = "icmp"
|
||||||
|
|
||||||
|
// ProtocolALL cover all supported protocols
|
||||||
|
ProtocolALL Protocol = "all"
|
||||||
|
)
|
||||||
@@ -84,13 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddPeerFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var ipset *nftables.Set
|
var ipset *nftables.Set
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
@@ -102,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -127,7 +127,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
log.Errorf("failed to delete mangle rule: %v", err)
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,7 +176,7 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||||
|
|
||||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||||
@@ -256,7 +256,6 @@ func (m *AclManager) addIOFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipset *nftables.Set,
|
ipset *nftables.Set,
|
||||||
comment string,
|
|
||||||
) (*Rule, error) {
|
) (*Rule, error) {
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
@@ -338,7 +337,7 @@ func (m *AclManager) addIOFiltering(
|
|||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
chain := m.chainInputRules
|
chain := m.chainInputRules
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ const (
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -342,6 +343,22 @@ func (m *Manager) Flush() error {
|
|||||||
return m.aclManager.Flush()
|
return m.aclManager.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -16,15 +16,15 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -37,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -47,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -171,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
|
|||||||
@@ -14,23 +14,31 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/google/nftables/xt"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
tableNat = "nat"
|
||||||
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
)
|
)
|
||||||
|
|
||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
@@ -49,6 +57,7 @@ type router struct {
|
|||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -59,6 +68,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
|||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
|
|||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
r.ipsetCounter.Clear()
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
return r.removeAcceptForwardRules()
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeAcceptForwardRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeNatPreroutingRules() error {
|
||||||
|
table := &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
}
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
}
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from nat table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Delete rules that have our UserData suffix
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingRdr,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
// Chain is created by acl manager
|
||||||
// TODO: move creation to a common place
|
// TODO: move creation to a common place
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
@@ -165,6 +228,7 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -173,7 +237,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -281,7 +345,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
nftRule, exists := r.rules[ruleKey]
|
nftRule, exists := r.rules[ruleKey]
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -410,6 +474,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -836,6 +904,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -896,6 +968,269 @@ func (r *router) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := protoToInt(rule.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.addDnatMasq(rule, protoNum, ruleKey)
|
||||||
|
|
||||||
|
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||||
|
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||||
|
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||||
|
// TODO: find chains with drop policies and add rules there
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
|
||||||
|
dnatExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||||
|
|
||||||
|
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||||
|
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||||
|
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||||
|
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||||
|
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||||
|
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: regProtoMin,
|
||||||
|
RegProtoMax: regProtoMax,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingRdr],
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
switch {
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
return r.handlePortRange(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
return r.handleAddressOnly(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
return r.handleSinglePort(rule)
|
||||||
|
default:
|
||||||
|
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 3,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 0, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Target{
|
||||||
|
Name: "DNAT",
|
||||||
|
Rev: 2,
|
||||||
|
Info: &xt.NatRange2{
|
||||||
|
NatRange: xt.NatRange{
|
||||||
|
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||||
|
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MinPort: rule.TranslatedPort.Values[0],
|
||||||
|
MaxPort: rule.TranslatedPort.Values[1],
|
||||||
|
},
|
||||||
|
BasePort: rule.DestinationPort.Values[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
},
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
},
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
|
||||||
|
masqExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 16,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||||
|
masqExprs = append(masqExprs, &expr.Masq{})
|
||||||
|
|
||||||
|
masqRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: masqExprs,
|
||||||
|
UserData: []byte(ruleKey + snatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(masqRule)
|
||||||
|
r.rules[ruleKey+snatSuffix] = masqRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr == nil {
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||||
var offset uint32
|
var offset uint32
|
||||||
@@ -959,15 +1294,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
if port.IsRange && len(port.Values) == 2 {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
// Handle port range
|
// Handle port range
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Cmp{
|
&expr.Range{
|
||||||
Op: expr.CmpOpGte,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
},
|
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpLte,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
t.Log("Internal rule expressions:")
|
t.Log("Internal rule expressions:")
|
||||||
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
var nftRule *nftables.Rule
|
var nftRule *nftables.Rule
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
if string(rule.UserData) == ruleKey.ID() {
|
||||||
nftRule = rule
|
nftRule = rule
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||||
payloadFound = true
|
payloadFound = true
|
||||||
}
|
}
|
||||||
case *expr.Cmp:
|
case *expr.Range:
|
||||||
if port.IsRange {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
fromPort := binary.BigEndian.Uint16(ex.FromData)
|
||||||
|
toPort := binary.BigEndian.Uint16(ex.ToData)
|
||||||
|
if fromPort == port.Values[0] && toPort == port.Values[1] {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if !port.IsRange {
|
||||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||||
for _, p := range port.Values {
|
for _, p := range port.Values {
|
||||||
if uint16(p) == portValue {
|
if p == portValue {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,6 +16,6 @@ type Rule struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,13 +3,12 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -17,7 +16,7 @@ func (i *InterfaceState) Name() string {
|
|||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -16,8 +17,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,28 +22,31 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ package common
|
|||||||
import (
|
import (
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
GetWGDevice() *wgdevice.Device
|
GetWGDevice() *wgdevice.Device
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,20 +1,27 @@
|
|||||||
// common.go
|
|
||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"sync"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
FlowId uuid.UUID
|
||||||
DestIP net.IP
|
Direction nftypes.Direction
|
||||||
SourcePort uint16
|
SourceIP netip.Addr
|
||||||
DestPort uint16
|
DestIP netip.Addr
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64
|
||||||
|
PacketsTx atomic.Uint64
|
||||||
|
PacketsRx atomic.Uint64
|
||||||
|
BytesTx atomic.Uint64
|
||||||
|
BytesRx atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCounters safely updates the packet and byte counters
|
||||||
|
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||||
|
if direction == nftypes.Egress {
|
||||||
|
b.PacketsTx.Add(1)
|
||||||
|
b.BytesTx.Add(uint64(bytes))
|
||||||
|
} else {
|
||||||
|
b.PacketsRx.Add(1)
|
||||||
|
b.BytesRx.Add(uint64(bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
|||||||
return time.Since(lastSeen) > timeout
|
return time.Since(lastSeen) > timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddr is a fixed-size IP address to avoid allocations
|
|
||||||
type IPAddr [16]byte
|
|
||||||
|
|
||||||
// MakeIPAddr creates an IPAddr from net.IP
|
|
||||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
|
||||||
// Optimization: check for v4 first as it's more common
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
copy(addr[12:], ip4)
|
|
||||||
} else {
|
|
||||||
copy(addr[:], ip.To16())
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnKey uniquely identifies a connection
|
// ConnKey uniquely identifies a connection
|
||||||
type ConnKey struct {
|
type ConnKey struct {
|
||||||
SrcIP IPAddr
|
SrcIP netip.Addr
|
||||||
DstIP IPAddr
|
DstIP netip.Addr
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
func (c ConnKey) String() string {
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
return ConnKey{
|
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
|
||||||
DstIP: MakeIPAddr(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateIPs checks if IPs match without allocation
|
|
||||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
|
||||||
if ip4 := pktIP.To4(); ip4 != nil {
|
|
||||||
// Compare IPv4 addresses (last 4 bytes)
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
if connIP[12+i] != ip4[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Compare full IPv6 addresses
|
|
||||||
ip6 := pktIP.To16()
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
if connIP[i] != ip6[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
|
||||||
type PreallocatedIPs struct {
|
|
||||||
sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPreallocatedIPs creates a new IP pool
|
|
||||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
|
||||||
return &PreallocatedIPs{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
ip := make(net.IP, 16)
|
|
||||||
return &ip
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves an IP from the pool
|
|
||||||
func (p *PreallocatedIPs) Get() net.IP {
|
|
||||||
return *p.Pool.Get().(*net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put returns an IP to the pool
|
|
||||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
|
||||||
p.Pool.Put(&ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyIP copies an IP address efficiently
|
|
||||||
func copyIP(dst, src net.IP) {
|
|
||||||
if len(src) == 16 {
|
|
||||||
copy(dst, src)
|
|
||||||
} else {
|
|
||||||
// Handle IPv4
|
|
||||||
copy(dst[12:], src.To4())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +1,67 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = MakeIPAddr(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ValidateIPs", func(b *testing.B) {
|
|
||||||
ip1 := net.ParseIP("192.168.1.1")
|
|
||||||
ip2 := net.ParseIP("192.168.1.1")
|
|
||||||
addr := MakeIPAddr(ip1)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ValidateIPs(addr, ip2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IPPool", func(b *testing.B) {
|
|
||||||
pool := NewPreallocatedIPs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ip := pool.Get()
|
|
||||||
pool.Put(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -19,18 +23,20 @@ const (
|
|||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
type ICMPConnKey struct {
|
type ICMPConnKey struct {
|
||||||
// Supports both IPv4 and IPv6
|
SrcIP netip.Addr
|
||||||
SrcIP [16]byte
|
DstIP netip.Addr
|
||||||
DstIP [16]byte
|
ID uint16
|
||||||
Sequence uint16 // ICMP sequence number
|
}
|
||||||
ID uint16 // ICMP identifier
|
|
||||||
|
func (i ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
type ICMPConnTrack struct {
|
type ICMPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
Sequence uint16
|
ICMPType uint8
|
||||||
ID uint16
|
ICMPCode uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
@@ -39,131 +45,201 @@ type ICMPTracker struct {
|
|||||||
connections map[ICMPConnKey]*ICMPConnTrack
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
flowLogger nftypes.FlowLogger
|
||||||
ipPool *PreallocatedIPs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &ICMPTracker{
|
tracker := &ICMPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP Echo Request
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
key := ICMPConnKey{
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
t.mutex.Lock()
|
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &ICMPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
},
|
|
||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New ICMP connection %v", key)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
return false
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
return key, false
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.ID == id &&
|
|
||||||
conn.Sequence == seq
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine() {
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
|
||||||
|
// non echo requests don't need tracking
|
||||||
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.tickerCancel()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanup() {
|
func (t *ICMPTracker) cleanup() {
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *ICMPTracker) Close() {
|
func (t *ICMPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
return ICMPConnKey{
|
FlowID: conn.FlowId,
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
Type: typ,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
RuleID: ruleID,
|
||||||
ID: id,
|
Direction: conn.Direction,
|
||||||
Sequence: seq,
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
}
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
ICMPType: conn.ICMPType,
|
||||||
|
ICMPCode: conn.ICMPCode,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeStart,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
fields.RxPackets = 1
|
||||||
|
fields.RxBytes = uint64(size)
|
||||||
|
} else {
|
||||||
|
fields.TxPackets = 1
|
||||||
|
fields.TxBytes = uint64(size)
|
||||||
|
}
|
||||||
|
t.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +1,39 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,12 +3,16 @@ package conntrack
|
|||||||
// TODO: Send RST packets for invalid/timed-out connections
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -39,6 +43,35 @@ const (
|
|||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPStateNew TCPState = iota
|
TCPStateNew TCPState = iota
|
||||||
TCPStateSynSent
|
TCPStateSynSent
|
||||||
@@ -53,19 +86,14 @@ const (
|
|||||||
TCPStateClosed
|
TCPStateClosed
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPConnKey uniquely identifies a TCP connection
|
|
||||||
type TCPConnKey struct {
|
|
||||||
SrcIP [16]byte
|
|
||||||
DstIP [16]byte
|
|
||||||
SrcPort uint16
|
|
||||||
DstPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
State TCPState
|
State TCPState
|
||||||
established atomic.Bool
|
established atomic.Bool
|
||||||
|
tombstone atomic.Bool
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,78 +107,126 @@ func (t *TCPConnTrack) SetEstablished(state bool) {
|
|||||||
t.established.Store(state)
|
t.established.Store(state)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
type TCPTracker struct {
|
type TCPTracker struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
connections map[ConnKey]*TCPConnTrack
|
connections map[ConnKey]*TCPConnTrack
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
done chan struct{}
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &TCPTracker{
|
tracker := &TCPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*TCPConnTrack),
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
key := ConnKey{
|
||||||
// Create key before lock
|
SrcIP: srcIP,
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
if !exists {
|
t.mutex.RUnlock()
|
||||||
// Use preallocated IPs
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &TCPConnTrack{
|
if exists {
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound TCP connection
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &TCPConnTrack{
|
||||||
BaseConnTrack: BaseConnTrack{
|
BaseConnTrack: BaseConnTrack{
|
||||||
SourceIP: srcIPCopy,
|
FlowId: uuid.New(),
|
||||||
DestIP: dstIPCopy,
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
},
|
|
||||||
State: TCPStateNew,
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.established.Store(false)
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
conn.established.Store(false)
|
||||||
}
|
conn.tombstone.Store(false)
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
// Lock individual connection for state update
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, true)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||||
if !isValidFlagCombination(flags) {
|
key := ConnKey{
|
||||||
return false
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
@@ -159,22 +235,26 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST packets
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
if flags&TCPRst != 0 {
|
if flags&TCPRst != 0 {
|
||||||
conn.Lock()
|
if conn.IsTombstone() {
|
||||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
conn.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Lock()
|
conn.Lock()
|
||||||
t.updateState(conn, flags, false)
|
conn.SetTombstone()
|
||||||
conn.UpdateLastSeen()
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
conn.Unlock()
|
||||||
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(key, conn, flags, false)
|
||||||
isEstablished := conn.IsEstablished()
|
isEstablished := conn.IsEstablished()
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
conn.Unlock()
|
conn.Unlock()
|
||||||
@@ -183,18 +263,17 @@ func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// updateState updates the TCP connection state based on flags
|
||||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
conn.UpdateLastSeen()
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
state := conn.State
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
defer func() {
|
||||||
return
|
if state != conn.State {
|
||||||
|
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
switch conn.State {
|
switch state {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
conn.State = TCPStateSynSent
|
||||||
@@ -203,11 +282,11 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
if isOutbound {
|
if isOutbound {
|
||||||
conn.State = TCPStateSynReceived
|
|
||||||
} else {
|
|
||||||
// Simultaneous open
|
|
||||||
conn.State = TCPStateEstablished
|
conn.State = TCPStateEstablished
|
||||||
conn.SetEstablished(true)
|
conn.SetEstablished(true)
|
||||||
|
} else {
|
||||||
|
// Simultaneous open
|
||||||
|
conn.State = TCPStateSynReceived
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -225,22 +304,32 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
conn.State = TCPStateCloseWait
|
conn.State = TCPStateCloseWait
|
||||||
}
|
}
|
||||||
conn.SetEstablished(false)
|
conn.SetEstablished(false)
|
||||||
|
} else if flags&TCPRst != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
switch {
|
switch {
|
||||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
// Simultaneous close - both sides sent FIN
|
|
||||||
conn.State = TCPStateClosing
|
conn.State = TCPStateClosing
|
||||||
case flags&TCPFin != 0:
|
case flags&TCPFin != 0:
|
||||||
conn.State = TCPStateFinWait2
|
conn.State = TCPStateFinWait2
|
||||||
case flags&TCPAck != 0:
|
case flags&TCPAck != 0:
|
||||||
conn.State = TCPStateFinWait2
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPRst != 0:
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection %s completed", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
@@ -248,8 +337,8 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
conn.State = TCPStateTimeWait
|
conn.State = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
// Keep established = false from previous state
|
||||||
|
|
||||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
@@ -260,17 +349,12 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
|
||||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
// Send close event for gracefully closed connections
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
|
||||||
// This is handled by the cleanup routine
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -315,12 +399,14 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPTracker) cleanupRoutine() {
|
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -331,6 +417,12 @@ func (t *TCPTracker) cleanup() {
|
|||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
switch {
|
||||||
case conn.State == TCPStateTimeWait:
|
case conn.State == TCPStateTimeWait:
|
||||||
@@ -341,29 +433,26 @@ func (t *TCPTracker) cleanup() {
|
|||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
lastSeen := conn.GetLastSeen()
|
if conn.timeoutExceeded(timeout) {
|
||||||
if time.Since(lastSeen) > timeout {
|
|
||||||
// Return IPs to pool
|
// Return IPs to pool
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if conn.State != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *TCPTracker) Close() {
|
func (t *TCPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -381,3 +470,21 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -9,11 +9,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -58,7 +58,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -76,17 +76,17 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Send initial SYN
|
// Send initial SYN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
// Receive SYN-ACK
|
// Receive SYN-ACK
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
// Send ACK
|
// Send ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
// Test data transfer
|
// Test data transfer
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||||
require.True(t, valid, "Data should be allowed after handshake")
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -99,18 +99,18 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Send FIN
|
// Send FIN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
// Receive ACK for FIN
|
// Receive ACK for FIN
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "ACK for FIN should be allowed")
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
// Receive FIN from other side
|
// Receive FIN from other side
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "FIN should be allowed")
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
// Send final ACK
|
// Send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -122,7 +122,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Receive RST
|
// Receive RST
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
require.True(t, valid, "RST should be allowed for established connection")
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||||
@@ -138,13 +138,13 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Both sides send FIN+ACK
|
// Both sides send FIN+ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
// Both sides send final ACK
|
// Both sides send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "Final ACKs should be allowed")
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,11 +162,11 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -181,12 +181,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST in established",
|
name: "RST in established",
|
||||||
setupState: func() {
|
setupState: func() {
|
||||||
// Establish connection first
|
// Establish connection first
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: true,
|
wantValid: true,
|
||||||
desc: "Should accept RST for established connection",
|
desc: "Should accept RST for established connection",
|
||||||
@@ -195,7 +195,7 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST without connection",
|
name: "RST without connection",
|
||||||
setupState: func() {},
|
setupState: func() {},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: false,
|
wantValid: false,
|
||||||
desc: "Should reject RST without connection",
|
desc: "Should reject RST without connection",
|
||||||
@@ -208,7 +208,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tt.sendRST()
|
tt.sendRST()
|
||||||
|
|
||||||
// Verify connection state is as expected
|
// Verify connection state is as expected
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
@@ -220,63 +225,63 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
i := 0
|
i := 0
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
if i%2 == 0 {
|
if i%2 == 0 {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
} else {
|
} else {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
i++
|
i++
|
||||||
}
|
}
|
||||||
@@ -287,14 +292,14 @@ func BenchmarkTCPTracker(b *testing.B) {
|
|||||||
// Benchmark connection cleanup
|
// Benchmark connection cleanup
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
// Pre-populate with expired connections
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
for i := 0; i < 10000; i++ {
|
for i := 0; i < 10000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wait for connections to expire
|
// Wait for connections to expire
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -18,6 +22,8 @@ const (
|
|||||||
// UDPConnTrack represents a UDP connection state
|
// UDPConnTrack represents a UDP connection state
|
||||||
type UDPConnTrack struct {
|
type UDPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
@@ -26,89 +32,125 @@ type UDPTracker struct {
|
|||||||
connections map[ConnKey]*UDPConnTrack
|
connections map[ConnKey]*UDPConnTrack
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
done chan struct{}
|
flowLogger nftypes.FlowLogger
|
||||||
ipPool *PreallocatedIPs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.mutex.Lock()
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &UDPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New UDP connection: %v", conn)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// TrackInbound records an inbound UDP connection
|
||||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.DestPort == srcPort &&
|
|
||||||
conn.SourcePort == dstPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
func (t *UDPTracker) cleanupRoutine() {
|
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-t.cleanupTicker.C:
|
case <-t.cleanupTicker.C:
|
||||||
t.cleanup()
|
t.cleanup()
|
||||||
case <-t.done:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -120,44 +162,58 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close stops the cleanup routine and releases resources
|
// Close stops the cleanup routine and releases resources
|
||||||
func (t *UDPTracker) Close() {
|
func (t *UDPTracker) Close() {
|
||||||
t.cleanupTicker.Stop()
|
t.tickerCancel()
|
||||||
close(t.done)
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection safely retrieves a connection state
|
// GetConnection safely retrieves a connection state
|
||||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
defer t.mutex.RUnlock()
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
conn, exists := t.connections[key]
|
SrcIP: srcIP,
|
||||||
if !exists {
|
DstIP: dstIP,
|
||||||
return nil, false
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
conn, exists := t.connections[key]
|
||||||
return conn, true
|
return conn, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout returns the configured timeout duration for the tracker
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
func (t *UDPTracker) Timeout() time.Duration {
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -29,54 +30,59 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout, logger)
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
assert.NotNil(t, tracker.cleanupTicker)
|
assert.NotNil(t, tracker.cleanupTicker)
|
||||||
assert.NotNil(t, tracker.done)
|
assert.NotNil(t, tracker.tickerCancel)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1*time.Second, logger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
// Track outbound connection
|
// Track outbound connection
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
@@ -93,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid source IP",
|
name: "invalid source IP",
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: srcIP,
|
dstIP: srcIP,
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
@@ -103,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid destination IP",
|
name: "invalid destination IP",
|
||||||
srcIP: dstIP,
|
srcIP: dstIP,
|
||||||
dstIP: net.ParseIP("192.168.1.4"),
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
sleep: 0,
|
sleep: 0,
|
||||||
@@ -143,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
if tt.sleep > 0 {
|
if tt.sleep > 0 {
|
||||||
time.Sleep(tt.sleep)
|
time.Sleep(tt.sleep)
|
||||||
}
|
}
|
||||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||||
assert.Equal(t, tt.want, got)
|
assert.Equal(t, tt.want, got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -154,42 +160,45 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout := 50 * time.Millisecond
|
timeout := 50 * time.Millisecond
|
||||||
cleanupInterval := 25 * time.Millisecond
|
cleanupInterval := 25 * time.Millisecond
|
||||||
|
|
||||||
|
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||||
|
defer tickerCancel()
|
||||||
|
|
||||||
// Create tracker with custom cleanup interval
|
// Create tracker with custom cleanup interval
|
||||||
tracker := &UDPTracker{
|
tracker := &UDPTracker{
|
||||||
connections: make(map[ConnKey]*UDPConnTrack),
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
done: make(chan struct{}),
|
tickerCancel: tickerCancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
go tracker.cleanupRoutine()
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.2"),
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
dstIP: net.ParseIP("192.168.1.3"),
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
srcPort: 12345,
|
srcPort: 12345,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: net.ParseIP("192.168.1.5"),
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
srcPort: 12346,
|
srcPort: 12346,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, conn := range connections {
|
for _, conn := range connections {
|
||||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify initial connections
|
// Verify initial connections
|
||||||
@@ -211,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
|||||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type epID stack.TransportEndpointID
|
||||||
|
|
||||||
|
func (i epID) String() string {
|
||||||
|
// src and remote is swapped
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -29,6 +30,7 @@ const (
|
|||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(mtu, logger),
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
|
|||||||
@@ -3,14 +3,30 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleICMP handles ICMP packets from the network stack
|
// handleICMP handles ICMP packets from the network stack
|
||||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
icmpType := uint8(icmpHdr.Type())
|
||||||
|
icmpCode := uint8(icmpHdr.Code())
|
||||||
|
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
// Get the complete ICMP message (header + data)
|
|
||||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
switch icmpHdr.Type() {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
case header.ICMPv4Echo:
|
f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||||
case header.ICMPv4EchoReply:
|
|
||||||
// dont process our own replies
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
_, err = conn.WriteTo(payload, dst)
|
|
||||||
if err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
return true
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
return true
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
|
// TODO: get packets/bytes
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,24 +5,38 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleTCP is called by the TCP forwarder for new connections.
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyTCP(id, inConn, outConn, ep)
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
// Create context for managing the proxy goroutines
|
||||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -16,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,11 +31,13 @@ type udpPacketConn struct {
|
|||||||
lastSeen atomic.Int64
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
bufPool sync.Pool
|
bufPool sync.Pool
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -44,10 +49,11 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
|||||||
for id, conn := range f.conns {
|
for id, conn := range f.conns {
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
if err := conn.conn.Close(); err != nil {
|
if err := conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := conn.outConn.Close(); err != nil {
|
if err := conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.ep.Close()
|
conn.ep.Close()
|
||||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
|||||||
for _, idle := range idleConns {
|
for _, idle := range idleConns {
|
||||||
idle.conn.cancel()
|
idle.conn.cancel()
|
||||||
if err := idle.conn.conn.Close(); err != nil {
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
if err := idle.conn.outConn.Close(); err != nil {
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idle.conn.ep.Close()
|
idle.conn.ep.Close()
|
||||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
|||||||
delete(f.conns, idle.id)
|
delete(f.conns, idle.id)
|
||||||
f.Unlock()
|
f.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
_, exists := f.udpForwarder.conns[id]
|
_, exists := f.udpForwarder.conns[id]
|
||||||
f.udpForwarder.RUnlock()
|
f.udpForwarder.RUnlock()
|
||||||
if exists {
|
if exists {
|
||||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
// TODO: Send ICMP error message
|
// TODO: Send ICMP error message
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if epErr != nil {
|
if epErr != nil {
|
||||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
cancel: connCancel,
|
cancel: connCancel,
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
}
|
}
|
||||||
pConn.updateLastSeen()
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
defer func() {
|
defer func() {
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) updateLastSeen() {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
c.lastSeen.Store(time.Now().UnixNano())
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -31,13 +32,9 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
ipv4 := ip.To4()
|
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||||
if ipv4 == nil {
|
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||||
return false
|
|
||||||
}
|
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,12 +119,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ip.Is4() {
|
||||||
return m.checkBitmapBit(ipv4)
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -2,90 +2,91 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLocalIPManager(t *testing.T) {
|
func TestLocalIPManager(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr iface.WGAddress
|
setupAddr wgaddr.Address
|
||||||
testIP net.IP
|
testIP netip.Addr
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: iface.WGAddress{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: net.ParseIP("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("fe80::"),
|
IP: net.ParseIP("fe80::"),
|
||||||
Mask: net.CIDRMask(64, 128),
|
Mask: net.CIDRMask(64, 128),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -95,7 +96,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
manager := newLocalIPManager()
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
mock := &IFaceMock{
|
mock := &IFaceMock{
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return tt.setupAddr
|
return tt.setupAddr
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -174,7 +175,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
t.Logf("Testing %d IPs", len(tests))
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,13 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2 // 2KB per message
|
maxMessageSize = 1024 * 2
|
||||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
// Level represents log severity
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
|||||||
LevelTrace: "TRAC",
|
LevelTrace: "TRAC",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type logMessage struct {
|
||||||
|
level Level
|
||||||
|
format string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
// Logger is a high-performance, non-blocking logger
|
// Logger is a high-performance, non-blocking logger
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
output io.Writer
|
output io.Writer
|
||||||
level atomic.Uint32
|
level atomic.Uint32
|
||||||
buffer *ringBuffer
|
msgChannel chan logMessage
|
||||||
shutdown chan struct{}
|
shutdown chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// Reusable buffer pool for formatting messages
|
|
||||||
bufPool sync.Pool
|
bufPool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
buffer: newRingBuffer(bufferSize),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
// Pre-allocate buffer for message formatting
|
|
||||||
b := make([]byte, 0, maxMessageSize)
|
b := make([]byte, 0, maxMessageSize)
|
||||||
return &b
|
return &b
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logrusLevel := logrusLogger.GetLevel()
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
l.level.Store(uint32(logrusLevel))
|
l.level.Store(uint32(logrusLevel))
|
||||||
level := levelStrings[Level(logrusLevel)]
|
level := levelStrings[Level(logrusLevel)]
|
||||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
func (l *Logger) SetLevel(level Level) {
|
func (l *Logger) SetLevel(level Level) {
|
||||||
l.level.Store(uint32(level))
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
*buf = (*buf)[:0]
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
// Timestamp
|
default:
|
||||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Level
|
|
||||||
*buf = append(*buf, levelStrings[level]...)
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Message
|
|
||||||
if len(args) > 0 {
|
|
||||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
|
||||||
} else {
|
|
||||||
*buf = append(*buf, format...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*buf = append(*buf, '\n')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
// Error logs a message at error level
|
||||||
bufp := l.bufPool.Get().(*[]byte)
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
l.formatMessage(bufp, level, format, args...)
|
|
||||||
|
|
||||||
if len(*bufp) > maxMessageSize {
|
|
||||||
*bufp = (*bufp)[:maxMessageSize]
|
|
||||||
}
|
|
||||||
_, _ = l.buffer.Write(*bufp)
|
|
||||||
|
|
||||||
l.bufPool.Put(bufp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
if l.level.Load() >= uint32(LevelError) {
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
l.log(LevelError, format, args...)
|
l.log(LevelError, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelWarn) {
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
l.log(LevelWarn, format, args...)
|
l.log(LevelWarn, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelInfo) {
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
l.log(LevelInfo, format, args...)
|
l.log(LevelInfo, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelDebug) {
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
l.log(LevelDebug, format, args...)
|
l.log(LevelDebug, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelTrace) {
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
l.log(LevelTrace, format, args...)
|
l.log(LevelTrace, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker periodically flushes the buffer
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
func (l *Logger) worker() {
|
func (l *Logger) worker() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(defaultFlushInterval)
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
buf := make([]byte, 0, maxBatchSize)
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.shutdown:
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Read accumulated messages
|
l.flushBuffer(&buffer)
|
||||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
case msg := <-l.msgChannel:
|
||||||
if n == 0 {
|
l.processMessage(msg, &buffer)
|
||||||
continue
|
l.processBatch(&buffer)
|
||||||
}
|
|
||||||
|
|
||||||
// Write batch
|
|
||||||
_, _ = l.output.Write(buf[:n])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package log
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// ringBuffer is a simple ring buffer implementation
|
|
||||||
type ringBuffer struct {
|
|
||||||
buf []byte
|
|
||||||
size int
|
|
||||||
r, w int64 // Read and write positions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRingBuffer(size int) *ringBuffer {
|
|
||||||
return &ringBuffer{
|
|
||||||
buf: make([]byte, size),
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p) > r.size {
|
|
||||||
p = p[:r.size]
|
|
||||||
}
|
|
||||||
|
|
||||||
n = len(p)
|
|
||||||
|
|
||||||
// Write data, handling wrap-around
|
|
||||||
pos := int(r.w % int64(r.size))
|
|
||||||
writeLen := min(len(p), r.size-pos)
|
|
||||||
copy(r.buf[pos:], p[:writeLen])
|
|
||||||
|
|
||||||
// If we have more data and need to wrap around
|
|
||||||
if writeLen < len(p) {
|
|
||||||
copy(r.buf, p[writeLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update write position
|
|
||||||
r.w += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.w == r.r {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate available data accounting for wraparound
|
|
||||||
available := int(r.w - r.r)
|
|
||||||
if available < 0 {
|
|
||||||
available += r.size
|
|
||||||
}
|
|
||||||
available = min(available, r.size)
|
|
||||||
|
|
||||||
// Limit read to buffer size
|
|
||||||
toRead := min(available, len(p))
|
|
||||||
if toRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read data, handling wrap-around
|
|
||||||
pos := int(r.r % int64(r.size))
|
|
||||||
readLen := min(toRead, r.size-pos)
|
|
||||||
n = copy(p, r.buf[pos:pos+readLen])
|
|
||||||
|
|
||||||
// If we need more data and need to wrap around
|
|
||||||
if readLen < toRead {
|
|
||||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update read position
|
|
||||||
r.r += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -12,25 +11,26 @@ import (
|
|||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
mgmtId []byte
|
||||||
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *PeerRule) GetRuleID() string {
|
func (r *PeerRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id string
|
||||||
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
destination netip.Prefix
|
destination netip.Prefix
|
||||||
proto firewall.Protocol
|
proto firewall.Protocol
|
||||||
@@ -39,7 +39,7 @@ type RouteRule struct {
|
|||||||
action firewall.Action
|
action firewall.Action
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *RouteRule) GetRuleID() string {
|
func (r *RouteRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketTrace struct {
|
type PacketTrace struct {
|
||||||
SourceIP net.IP
|
SourceIP netip.Addr
|
||||||
DestinationIP net.IP
|
DestinationIP netip.Addr
|
||||||
Protocol string
|
Protocol string
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketBuilder struct {
|
type PacketBuilder struct {
|
||||||
SrcIP net.IP
|
SrcIP netip.Addr
|
||||||
DstIP net.IP
|
DstIP netip.Addr
|
||||||
Protocol fw.Protocol
|
Protocol fw.Protocol
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
|||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP,
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP,
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !m.handleRouting(trace) {
|
if !m.handleRouting(trace) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return m.handleNativeRouter(trace)
|
return m.handleNativeRouter(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||||
msg := "No existing connection found"
|
msg := "No existing connection found"
|
||||||
if allowed {
|
if allowed {
|
||||||
msg = m.buildConntrackStateMessage(d)
|
msg = m.buildConntrackStateMessage(d)
|
||||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.localForwarding {
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
strRuleId := "<no id>"
|
||||||
|
if ruleId != nil {
|
||||||
|
strRuleId = string(ruleId)
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||||
|
if blocked {
|
||||||
|
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||||
|
trace.AddResult(StagePeerACL, msg, false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
trace.AddResult(StagePeerACL, msg, true)
|
||||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
|
||||||
|
|
||||||
msg := "Allowed by peer ACL rules"
|
|
||||||
if blocked {
|
|
||||||
msg = "Blocked by peer ACL rules"
|
|
||||||
}
|
|
||||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
|
||||||
|
|
||||||
|
// Handle netstack mode
|
||||||
if m.netstack {
|
if m.netstack {
|
||||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
switch {
|
||||||
|
case !m.localForwarding:
|
||||||
|
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||||
|
case m.forwarder.Load() != nil:
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
default:
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
// In normal mode, packets are allowed through for local delivery
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
return false
|
return false
|
||||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
proto := getProtocolFromPacket(d)
|
proto, _ := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
msg := "Allowed by route ACLs"
|
strId := string(id)
|
||||||
|
if id == nil {
|
||||||
|
strId = "<no id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
msg = "Blocked by route ACLs"
|
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||||
}
|
}
|
||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData)
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||||
|
t.Logf("Trace results: %v", trace.Results)
|
||||||
|
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
actualStages = append(actualStages, result.Stage)
|
||||||
|
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||||
|
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||||
|
lastResult := trace.Results[len(trace.Results)-1]
|
||||||
|
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||||
|
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTracePacket(t *testing.T) {
|
||||||
|
setupTracerTest := func(statefulMode bool) *Manager {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !statefulMode {
|
||||||
|
m.stateful = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
builder := &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocol == "tcp" {
|
||||||
|
builder.TCPState = &TCPState{SYN: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
}
|
||||||
|
|
||||||
|
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
return &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: "icmp",
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Manager)
|
||||||
|
packetBuilder func() *PacketBuilder
|
||||||
|
expectedStages []PacketStage
|
||||||
|
expectedAllow bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = true
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithoutForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_NativeRouter",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_RoutingDisabled",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ConnectionTracking_Hit",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||||
|
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||||
|
return pb
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OutboundTraffic",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPEchoRequest",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPDestinationUnreachable",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithoutHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
hookFunc := func([]byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StatefulDisabled_NoTracking",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.stateful = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m := setupTracerTest(true)
|
||||||
|
|
||||||
|
tc.setup(m)
|
||||||
|
|
||||||
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||||
|
"172.17.0.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
trace, err := m.TracePacketFromBuilder(pb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyTraceStages(t, trace, tc.expectedStages)
|
||||||
|
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -42,6 +44,8 @@ const (
|
|||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||||
|
|
||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
@@ -63,9 +67,9 @@ func (r RouteRules) Sort() {
|
|||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
// outgoingRules is used for hooks only
|
// outgoingRules is used for hooks only
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[netip.Addr]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
@@ -77,9 +81,9 @@ type Manager struct {
|
|||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
// indicates whether we forward packets not destined for ourselves
|
// indicates whether we forward packets not destined for ourselves
|
||||||
routingEnabled bool
|
routingEnabled atomic.Bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
nativeRouter bool
|
nativeRouter atomic.Bool
|
||||||
// indicates whether we track outbound connections
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
// indicates whether wireguards runs in netstack mode
|
// indicates whether wireguards runs in netstack mode
|
||||||
@@ -92,8 +96,9 @@ type Manager struct {
|
|||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder *forwarder.Forwarder
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -110,16 +115,16 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes)
|
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
if nativeFirewall == nil {
|
if nativeFirewall == nil {
|
||||||
return nil, errors.New("native firewall is nil")
|
return nil, errors.New("native firewall is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -146,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -164,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[netip.Addr]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: disableServerRoutes,
|
disableServerRoutes: disableServerRoutes,
|
||||||
routingEnabled: false,
|
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
}
|
}
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@@ -183,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
// netstack needs the forwarder for local traffic
|
||||||
@@ -206,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
@@ -216,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
|||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
if _, err := m.AddRouteFiltering(
|
if _, err := m.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
wgPrefix,
|
wgPrefix,
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
@@ -249,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
log.Info("userspace routing is disabled")
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
case m.disableServerRoutes:
|
case m.disableServerRoutes:
|
||||||
// if server routes are disabled we will let packets pass to the native stack
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
case forceUserspaceRouter:
|
case forceUserspaceRouter:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
@@ -270,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
|||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("native routing is enabled")
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled && !m.nativeRouter {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
return m.initForwarder()
|
return m.initForwarder()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -291,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder() error {
|
func (m *Manager) initForwarder() error {
|
||||||
if m.forwarder != nil {
|
if m.forwarder.Load() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
intf := m.wgIface.GetWGDevice()
|
intf := m.wgIface.GetWGDevice()
|
||||||
if intf == nil {
|
if intf == nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder = forwarder
|
m.forwarder.Store(forwarder)
|
||||||
|
|
||||||
log.Debug("forwarder initialized")
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
@@ -324,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -346,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
_ string,
|
_ string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
|
// TODO: fix in upper layers
|
||||||
|
i, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i.Unmap()
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
mgmtId: id,
|
||||||
|
ip: i,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
|
||||||
}
|
}
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
if i.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||||
@@ -389,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip] = make(RuleSet)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -405,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: ruleID,
|
||||||
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
destination: destination,
|
destination: destination,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
@@ -423,21 +436,23 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
action: action,
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
m.routeRules = append(m.routeRules, rule)
|
m.routeRules = append(m.routeRules, rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
ruleID := rule.GetRuleID()
|
ruleID := rule.ID()
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == ruleID
|
||||||
})
|
})
|
||||||
@@ -459,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
delete(m.incomingRules[r.ip], r.id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -478,14 +493,30 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData)
|
return m.processOutgoingHooks(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData)
|
return m.dropFilter(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -493,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -509,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track all protocols if stateful mode is enabled
|
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||||
if m.stateful {
|
return true
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process UDP hooks even if stateful mode is disabled
|
if m.stateful {
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||||
switch d.decoded[0] {
|
switch d.decoded[0] {
|
||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
return d.ip4.SrcIP, d.ip4.DstIP
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
return src, dst
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
return d.ip6.SrcIP, d.ip6.DstIP
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
|
return src, dst
|
||||||
default:
|
default:
|
||||||
return nil, nil
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|
||||||
flags := getTCPFlags(&d.tcp)
|
|
||||||
m.tcpTracker.TrackOutbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
uint16(d.tcp.SrcPort),
|
|
||||||
uint16(d.tcp.DstPort),
|
|
||||||
flags,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
var flags uint8
|
var flags uint8
|
||||||
if tcp.SYN {
|
if tcp.SYN {
|
||||||
@@ -578,45 +591,70 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||||
m.udpTracker.TrackOutbound(
|
transport := d.decoded[1]
|
||||||
srcIP,
|
switch transport {
|
||||||
dstIP,
|
case layers.LayerTypeUDP:
|
||||||
uint16(d.udp.SrcPort),
|
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||||
uint16(d.udp.DstPort),
|
case layers.LayerTypeTCP:
|
||||||
)
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
transport := d.decoded[1]
|
||||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
switch transport {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||||
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Check specific destination IP first
|
||||||
|
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
// Check IPv4 unspecified address
|
||||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||||
m.icmpTracker.TrackOutbound(
|
for _, rule := range rules {
|
||||||
srcIP,
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
dstIP,
|
return rule.udpHook(packetData)
|
||||||
d.icmp4.Id,
|
|
||||||
d.icmp4.Seq,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IPv6 unspecified address
|
||||||
|
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -625,19 +663,19 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.localipmanager.IsLocalIP(dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
||||||
@@ -645,10 +683,29 @@ func (m *Manager) dropFilter(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleLocalTraffic handles local traffic.
|
// handleLocalTraffic handles local traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
if blocked {
|
||||||
srcIP, dstIP)
|
_, pnum := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
RxPackets: 1,
|
||||||
|
RxBytes: uint64(size),
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -657,6 +714,9 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData
|
|||||||
return m.handleNetstackLocalTraffic(packetData)
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
|
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -666,12 +726,12 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject local packet: %v", err)
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -681,30 +741,43 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleRoutedTraffic handles routed traffic.
|
// handleRoutedTraffic handles routed traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
srcIP, dstIP)
|
srcIP, dstIP)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass to native stack if native router is enabled or forced
|
// Pass to native stack if native router is enabled or forced
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := getProtocolFromPacket(d)
|
proto, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
srcIP, srcPort, dstIP, dstPort, proto)
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let forwarder handle the packet if it passed route ACLs
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := m.forwarder.Load().InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
m.logger.Error("Failed to inject incoming packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -712,16 +785,16 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetDat
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return firewall.ProtocolTCP
|
return firewall.ProtocolTCP, nftypes.TCP
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
return firewall.ProtocolUDP
|
return firewall.ProtocolUDP, nftypes.UDP
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return firewall.ProtocolICMP
|
return firewall.ProtocolICMP, nftypes.ICMP
|
||||||
default:
|
default:
|
||||||
return firewall.ProtocolALL
|
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -749,7 +822,7 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return m.tcpTracker.IsValidInbound(
|
return m.tcpTracker.IsValidInbound(
|
||||||
@@ -758,6 +831,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
uint16(d.tcp.SrcPort),
|
uint16(d.tcp.SrcPort),
|
||||||
uint16(d.tcp.DstPort),
|
uint16(d.tcp.DstPort),
|
||||||
getTCPFlags(&d.tcp),
|
getTCPFlags(&d.tcp),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@@ -766,6 +840,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
dstIP,
|
dstIP,
|
||||||
uint16(d.udp.SrcPort),
|
uint16(d.udp.SrcPort),
|
||||||
uint16(d.udp.DstPort),
|
uint16(d.udp.DstPort),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@@ -773,8 +848,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
srcIP,
|
srcIP,
|
||||||
dstIP,
|
dstIP,
|
||||||
d.icmp4.Id,
|
d.icmp4.Id,
|
||||||
d.icmp4.Seq,
|
|
||||||
d.icmp4.TypeCode.Type(),
|
d.icmp4.TypeCode.Type(),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: ICMPv6
|
// TODO: ICMPv6
|
||||||
@@ -794,25 +869,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
|||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
if m.isSpecialICMP(d) {
|
if m.isSpecialICMP(d) {
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default policy: DROP ALL
|
// Default policy: DROP ALL
|
||||||
return true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||||
@@ -832,15 +909,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -850,39 +927,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData), true
|
return rule.mgmtId, rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
|
||||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
|
||||||
|
|
||||||
for _, rule := range m.routeRules {
|
for _, rule := range m.routeRules {
|
||||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||||
return rule.action == firewall.ActionAccept
|
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
@@ -922,36 +996,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -999,20 +1069,21 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.forwarder == nil {
|
fwder := m.forwarder.Load()
|
||||||
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
// don't stop forwarder if in use by netstack
|
// don't stop forwarder if in use by netstack
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
m.forwarder = nil
|
m.forwarder.Store(nil)
|
||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
|||||||
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
fw.ActionAccept, "", "allow all")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept, "", "explicit return")
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
_, err := m.AddPeerFiltering(
|
||||||
fw.ActionDrop, "", "default drop")
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionDrop,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut)
|
manager.processOutgoingHooks(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn)
|
manager.dropFilter(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(
|
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
r.sources,
|
|
||||||
r.dest,
|
|
||||||
r.proto,
|
|
||||||
nil,
|
|
||||||
r.port,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,9 +12,9 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
@@ -26,15 +26,15 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: localIP,
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction,
|
tc.ruleAction,
|
||||||
"",
|
"",
|
||||||
tc.name,
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -288,8 +288,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: localIP,
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
@@ -302,12 +302,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Close(nil))
|
require.NoError(tb, manager.Close(nil))
|
||||||
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
tc.rule.proto,
|
tc.rule.proto,
|
||||||
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
var rules []fw.Rule
|
var rules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
r.proto,
|
r.proto,
|
||||||
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i, p := range tc.packets {
|
for i, p := range tc.packets {
|
||||||
srcIP := net.ParseIP(p.srcIP)
|
srcIP := netip.MustParseAddr(p.srcIP)
|
||||||
dstIP := net.ParseIP(p.dstIP)
|
dstIP := netip.MustParseAddr(p.dstIP)
|
||||||
|
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,8 +1,10 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,15 +18,17 @@ import (
|
|||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
GetWGDeviceFunc func() *wgdevice.Device
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
GetDeviceFunc func() *device.FilteredDevice
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
}
|
}
|
||||||
@@ -50,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
|||||||
return i.SetFilterFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) Address() iface.WGAddress {
|
func (i *IFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc == nil {
|
if i.AddressFunc == nil {
|
||||||
return iface.WGAddress{}
|
return wgaddr.Address{}
|
||||||
}
|
}
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule PeerRule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -246,9 +248,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -268,8 +269,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes()) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -347,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -401,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
false,
|
false,
|
||||||
net.ParseIP("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
hookCalled = true
|
hookCalled = true
|
||||||
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes())
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes())
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -534,8 +534,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
srcIP := net.ParseIP("100.10.0.1")
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
dstIP := net.ParseIP("100.10.0.100")
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
srcPort := uint16(51334)
|
srcPort := uint16(51334)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
outboundIPv4 := &layers.IPv4{
|
outboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP.AsSlice(),
|
||||||
DstIP: dstIP,
|
DstIP: dstIP.AsSlice(),
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
outboundUDP := &layers.UDP{
|
outboundUDP := &layers.UDP{
|
||||||
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
inboundIPv4 := &layers.IPv4{
|
inboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: dstIP, // Original destination is now source
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
DstIP: srcIP, // Original source is now destination
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
inboundUDP := &layers.UDP{
|
inboundUDP := &layers.UDP{
|
||||||
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -707,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes())
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import (
|
|||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -51,9 +53,10 @@ type ICEBind struct {
|
|||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
@@ -63,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
|||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
|
address: address,
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -145,6 +149,7 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
UDPConn: conn,
|
UDPConn: conn,
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
|
WGAddress: s.address,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FilterFn is a function that filters out candidates based on the address.
|
// FilterFn is a function that filters out candidates based on the address.
|
||||||
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
XORMappedAddrCacheTTL time.Duration
|
XORMappedAddrCacheTTL time.Duration
|
||||||
Net transport.Net
|
Net transport.Net
|
||||||
FilterFn FilterFn
|
FilterFn FilterFn
|
||||||
|
WGAddress wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
filterFn: params.FilterFn,
|
filterFn: params.FilterFn,
|
||||||
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed UDPMux
|
// embed UDPMux
|
||||||
@@ -118,6 +122,7 @@ type udpConn struct {
|
|||||||
filterFn FilterFn
|
filterFn FilterFn
|
||||||
// TODO: reset cache on route changes
|
// TODO: reset cache on route changes
|
||||||
addrCache sync.Map
|
addrCache sync.Map
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u.address.Network.Contains(a.AsSlice()) {
|
||||||
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
}
|
||||||
|
|
||||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -9,13 +9,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create() (device.WGConfigurer, error)
|
Create() (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address WGAddress) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() WGAddress
|
WgAddress() wgaddr.Address
|
||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
type WGTunDevice struct {
|
type WGTunDevice struct {
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -31,7 +32,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -93,7 +94,7 @@ func (t *WGTunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) UpdateAddr(addr WGAddress) error {
|
func (t *WGTunDevice) UpdateAddr(addr wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -123,7 +124,7 @@ func (t *WGTunDevice) DeviceName() string {
|
|||||||
return t.name
|
return t.name
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) WgAddress() WGAddress {
|
func (t *WGTunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,11 +13,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -29,7 +30,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -85,7 +86,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -106,7 +107,7 @@ func (t *TunDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -10,16 +11,16 @@ import (
|
|||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// DropOutgoing filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte) bool
|
DropOutgoing(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte) bool
|
DropIncoming(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Hook function receives raw network packet data as argument.
|
||||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:]) {
|
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -14,11 +14,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
@@ -30,7 +31,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, iceBind *bind.ICEBind, tunFd int) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -120,11 +121,11 @@ func (t *TunDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(addr WGAddress) error {
|
func (t *TunDevice) UpdateAddr(_ wgaddr.Address) error {
|
||||||
// todo implement
|
// todo implement
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,12 +14,13 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/sharedsock"
|
"github.com/netbirdio/netbird/sharedsock"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunKernelDevice struct {
|
type TunKernelDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
wgPort int
|
wgPort int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -34,7 +35,7 @@ type TunKernelDevice struct {
|
|||||||
filterFn bind.FilterFn
|
filterFn bind.FilterFn
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewKernelDevice(name string, address WGAddress, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
func NewKernelDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, transportNet transport.Net) *TunKernelDevice {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
return &TunKernelDevice{
|
return &TunKernelDevice{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
@@ -102,6 +103,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
UDPConn: rawSock,
|
UDPConn: rawSock,
|
||||||
Net: t.transportNet,
|
Net: t.transportNet,
|
||||||
FilterFn: t.filterFn,
|
FilterFn: t.filterFn,
|
||||||
|
WGAddress: t.address,
|
||||||
}
|
}
|
||||||
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
mux := bind.NewUniversalUDPMuxDefault(bindParams)
|
||||||
go mux.ReadFromConn(t.ctx)
|
go mux.ReadFromConn(t.ctx)
|
||||||
@@ -112,7 +114,7 @@ func (t *TunKernelDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return t.udpMux, nil
|
return t.udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunKernelDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -145,7 +147,7 @@ func (t *TunKernelDevice) Close() error {
|
|||||||
return closErr
|
return closErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunKernelDevice) WgAddress() WGAddress {
|
func (t *TunKernelDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TunNetstackDevice struct {
|
type TunNetstackDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -34,7 +35,7 @@ type TunNetstackDevice struct {
|
|||||||
net *netstack.Net
|
net *netstack.Net
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetstackDevice(name string, address WGAddress, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
func NewNetstackDevice(name string, address wgaddr.Address, wgPort int, key string, mtu int, iceBind *bind.ICEBind, listenAddress string) *TunNetstackDevice {
|
||||||
return &TunNetstackDevice{
|
return &TunNetstackDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -97,7 +98,7 @@ func (t *TunNetstackDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) UpdateAddr(WGAddress) error {
|
func (t *TunNetstackDevice) UpdateAddr(wgaddr.Address) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -116,7 +117,7 @@ func (t *TunNetstackDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunNetstackDevice) WgAddress() WGAddress {
|
func (t *TunNetstackDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,11 +12,12 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type USPDevice struct {
|
type USPDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -28,7 +29,7 @@ type USPDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
func NewUSPDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *USPDevice {
|
||||||
log.Infof("using userspace bind mode")
|
log.Infof("using userspace bind mode")
|
||||||
|
|
||||||
return &USPDevice{
|
return &USPDevice{
|
||||||
@@ -93,7 +94,7 @@ func (t *USPDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) UpdateAddr(address WGAddress) error {
|
func (t *USPDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -113,7 +114,7 @@ func (t *USPDevice) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *USPDevice) WgAddress() WGAddress {
|
func (t *USPDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,13 +13,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
const defaultWindowsGUIDSTring = "{f2f29e61-d91f-4d76-8151-119b20c4bdeb}"
|
||||||
|
|
||||||
type TunDevice struct {
|
type TunDevice struct {
|
||||||
name string
|
name string
|
||||||
address WGAddress
|
address wgaddr.Address
|
||||||
port int
|
port int
|
||||||
key string
|
key string
|
||||||
mtu int
|
mtu int
|
||||||
@@ -32,7 +33,7 @@ type TunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(name string, address WGAddress, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
func NewTunDevice(name string, address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind) *TunDevice {
|
||||||
return &TunDevice{
|
return &TunDevice{
|
||||||
name: name,
|
name: name,
|
||||||
address: address,
|
address: address,
|
||||||
@@ -118,7 +119,7 @@ func (t *TunDevice) Up() (*bind.UniversalUDPMuxDefault, error) {
|
|||||||
return udpMux, nil
|
return udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TunDevice) UpdateAddr(address WGAddress) error {
|
func (t *TunDevice) UpdateAddr(address wgaddr.Address) error {
|
||||||
t.address = address
|
t.address = address
|
||||||
return t.assignAddr()
|
return t.assignAddr()
|
||||||
}
|
}
|
||||||
@@ -139,7 +140,7 @@ func (t *TunDevice) Close() error {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
func (t *TunDevice) WgAddress() WGAddress {
|
func (t *TunDevice) WgAddress() wgaddr.Address {
|
||||||
return t.address
|
return t.address
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/freebsd"
|
"github.com/netbirdio/netbird/client/iface/freebsd"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgLink struct {
|
type wgLink struct {
|
||||||
@@ -56,7 +57,7 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
link, err := freebsd.LinkByName(l.name)
|
link, err := freebsd.LinkByName(l.name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("link by name: %w", err)
|
return fmt.Errorf("link by name: %w", err)
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/vishvananda/netlink"
|
"github.com/vishvananda/netlink"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type wgLink struct {
|
type wgLink struct {
|
||||||
@@ -90,7 +92,7 @@ func (l *wgLink) up() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *wgLink) assignAddr(address WGAddress) error {
|
func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
||||||
//delete existing addresses
|
//delete existing addresses
|
||||||
list, err := netlink.AddrList(l, 0)
|
list, err := netlink.AddrList(l, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -7,13 +7,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address WGAddress) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() WGAddress
|
WgAddress() wgaddr.Address
|
||||||
DeviceName() string
|
DeviceName() string
|
||||||
Close() error
|
Close() error
|
||||||
FilteredDevice() *device.FilteredDevice
|
FilteredDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -28,8 +29,6 @@ const (
|
|||||||
WgInterfaceDefault = configurer.WgInterfaceDefault
|
WgInterfaceDefault = configurer.WgInterfaceDefault
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGAddress = device.WGAddress
|
|
||||||
|
|
||||||
type wgProxyFactory interface {
|
type wgProxyFactory interface {
|
||||||
GetProxy() wgproxy.Proxy
|
GetProxy() wgproxy.Proxy
|
||||||
Free() error
|
Free() error
|
||||||
@@ -72,7 +71,7 @@ func (w *WGIface) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Address returns the interface address
|
// Address returns the interface address
|
||||||
func (w *WGIface) Address() device.WGAddress {
|
func (w *WGIface) Address() wgaddr.Address {
|
||||||
return w.tun.WgAddress()
|
return w.tun.WgAddress()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -103,7 +102,7 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
addr, err := device.ParseWGAddress(newAddr)
|
addr, err := wgaddr.ParseWGAddress(newAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,17 +3,18 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
|
|||||||
@@ -6,17 +6,18 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
var tun WGTunDevice
|
var tun WGTunDevice
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
|
|||||||
@@ -5,17 +5,18 @@ package iface
|
|||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
tun: device.NewTunDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, iceBind, opts.MobileArgs.TunFd),
|
||||||
|
|||||||
@@ -8,12 +8,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -21,7 +22,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
wgIFace := &WGIface{}
|
wgIFace := &WGIface{}
|
||||||
|
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
wgIFace.tun = device.NewNetstackDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, netstack.ListenAddr())
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
@@ -34,7 +35,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
}
|
}
|
||||||
if device.ModuleTunIsLoaded() {
|
if device.ModuleTunIsLoaded() {
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
wgIFace.tun = device.NewUSPDevice(opts.IFaceName, wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind)
|
||||||
wgIFace.userspaceBind = true
|
wgIFace.userspaceBind = true
|
||||||
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
wgIFace.wgProxyFactory = wgproxy.NewUSPFactory(iceBind)
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
wgaddr "github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewWGIFace Creates a new WireGuard interface instance
|
// NewWGIFace Creates a new WireGuard interface instance
|
||||||
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||||
wgAddress, err := device.ParseWGAddress(opts.Address)
|
wgAddress, err := wgaddr.ParseWGAddress(opts.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn)
|
iceBind := bind.NewICEBind(opts.TransportNet, opts.FilterFn, wgAddress)
|
||||||
|
|
||||||
var tun WGTunDevice
|
var tun WGTunDevice
|
||||||
if netstack.IsEnabled() {
|
if netstack.IsEnabled() {
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
net "net"
|
||||||
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(string)
|
ret0, _ := ret[0].(string)
|
||||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
|||||||
@@ -1,29 +1,29 @@
|
|||||||
package device
|
package wgaddr
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGAddress WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type WGAddress struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP net.IP
|
||||||
Network *net.IPNet
|
Network *net.IPNet
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (WGAddress, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
ip, network, err := net.ParseCIDR(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGAddress{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return WGAddress{
|
return Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr WGAddress) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
maskSize, _ := addr.Network.Mask.Size()
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
||||||
}
|
}
|
||||||
@@ -22,6 +22,8 @@
|
|||||||
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
!define UI_REG_APP_PATH "Software\Microsoft\Windows\CurrentVersion\App Paths\${UI_APP_EXE}"
|
||||||
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
!define UI_UNINSTALL_PATH "Software\Microsoft\Windows\CurrentVersion\Uninstall\${UI_APP_NAME}"
|
||||||
|
|
||||||
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -68,6 +70,9 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
|
; Custom page for autostart checkbox
|
||||||
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
@@ -80,8 +85,36 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_LANGUAGE "English"
|
!insertmacro MUI_LANGUAGE "English"
|
||||||
|
|
||||||
|
; Variables for autostart option
|
||||||
|
Var AutostartCheckbox
|
||||||
|
Var AutostartEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
; Function to create the autostart options page
|
||||||
|
Function AutostartPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Startup Options" "Configure how ${APP_NAME} launches with Windows."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
|
Pop $AutostartCheckbox
|
||||||
|
${NSD_Check} $AutostartCheckbox ; Default to checked
|
||||||
|
StrCpy $AutostartEnabled "1" ; Default to enabled
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the autostart page
|
||||||
|
Function AutostartPageLeave
|
||||||
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -163,6 +196,16 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
|
|||||||
|
|
||||||
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
||||||
|
|
||||||
|
; Create autostart registry entry based on checkbox
|
||||||
|
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||||
|
${If} $AutostartEnabled == "1"
|
||||||
|
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||||
|
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||||
|
${Else}
|
||||||
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
DetailPrint "Autostart not enabled by user"
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::AddValueEx "path" "$INSTDIR"
|
EnVar::AddValueEx "path" "$INSTDIR"
|
||||||
|
|
||||||
@@ -186,7 +229,10 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
|||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
# kill ui client
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
|
; Remove autostart registry entry
|
||||||
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
type RuleID string
|
type RuleID string
|
||||||
|
|
||||||
func (r RuleID) GetRuleID() string {
|
func (r RuleID) ID() string {
|
||||||
return string(r)
|
return string(r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -240,12 +240,12 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddRouteFiltering(rule.Id, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return id.RuleID(addedRule.GetRuleID()), nil
|
return id.RuleID(addedRule.ID()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) protoRuleToFirewallRule(
|
func (d *DefaultManager) protoRuleToFirewallRule(
|
||||||
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
return ruleID, rulesPair, nil
|
return ruleID, rulesPair, nil
|
||||||
}
|
}
|
||||||
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(r.Id, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(r.Id, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
comment string,
|
|
||||||
) id.RuleID {
|
) id.RuleID {
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||||
if port != nil {
|
if port != nil {
|
||||||
idStr += port.String()
|
idStr += port.String()
|
||||||
}
|
}
|
||||||
@@ -515,7 +514,7 @@ func (d *DefaultManager) rollBack(newRulePairs map[id.RuleID][]firewall.Rule) {
|
|||||||
for _, rules := range newRulePairs {
|
for _, rules := range newRulePairs {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
if err := d.firewall.DeletePeerRule(rule); err != nil {
|
||||||
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.GetRuleID(), err)
|
log.Errorf("failed to delete new firewall rule (id: %v) during rollback: %v", rule.ID(), err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -8,11 +9,14 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -45,14 +49,14 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -74,7 +78,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
t.Run("add extra rules", func(t *testing.T) {
|
t.Run("add extra rules", func(t *testing.T) {
|
||||||
existedPairs := map[string]struct{}{}
|
existedPairs := map[string]struct{}{}
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
existedPairs[id.GetRuleID()] = struct{}{}
|
existedPairs[id.ID()] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove first rule
|
// remove first rule
|
||||||
@@ -100,7 +104,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
if _, ok := existedPairs[id.GetRuleID()]; ok {
|
if _, ok := existedPairs[id.ID()]; ok {
|
||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -339,14 +343,14 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(iface.WGAddress{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
iface "github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockIFaceMapper is a mock of IFaceMapper interface.
|
// MockIFaceMapper is a mock of IFaceMapper interface.
|
||||||
@@ -38,10 +38,10 @@ func (m *MockIFaceMapper) EXPECT() *MockIFaceMapperMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Address mocks base method.
|
// Address mocks base method.
|
||||||
func (m *MockIFaceMapper) Address() iface.WGAddress {
|
func (m *MockIFaceMapper) Address() wgaddr.Address {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Address")
|
ret := m.ctrl.Call(m, "Address")
|
||||||
ret0, _ := ret[0].(iface.WGAddress)
|
ret0, _ := ret[0].(wgaddr.Address)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func NewConnectClient(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan error) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}) error {
|
||||||
return c.run(MobileDependency{}, runningChan)
|
return c.run(MobileDependency{}, runningChan)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -102,7 +102,7 @@ func (c *ConnectClient) RunOniOS(
|
|||||||
return c.run(mobileDependency, nil)
|
return c.run(mobileDependency, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan error) error {
|
func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan struct{}) error {
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
rec := c.statusRecorder
|
rec := c.statusRecorder
|
||||||
@@ -159,10 +159,9 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
runningChanOpen := true
|
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
if c.isContextCancelled() {
|
if c.ctx.Err() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -282,10 +281,11 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil && runningChanOpen {
|
if runningChan != nil {
|
||||||
runningChan <- nil
|
select {
|
||||||
close(runningChan)
|
case runningChan <- struct{}{}:
|
||||||
runningChanOpen = false
|
default:
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
@@ -379,15 +379,6 @@ func (c *ConnectClient) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ConnectClient) isContextCancelled() bool {
|
|
||||||
select {
|
|
||||||
case <-c.ctx.Done():
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetworkMapPersistence enables or disables network map persistence.
|
// SetNetworkMapPersistence enables or disables network map persistence.
|
||||||
// When enabled, the last received network map will be stored and can be retrieved
|
// When enabled, the last received network map will be stored and can be retrieved
|
||||||
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
// through the Engine's getLatestNetworkMap method. When disabled, any stored
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
@@ -29,6 +31,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
}
|
}
|
||||||
@@ -37,9 +41,9 @@ func (w *mocWGIface) Name() string {
|
|||||||
panic("implement me")
|
panic("implement me")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() iface.WGAddress {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
ip, network, _ := net.ParseCIDR("100.66.100.0/24")
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: ip,
|
IP: ip,
|
||||||
Network: network,
|
Network: network,
|
||||||
}
|
}
|
||||||
@@ -455,7 +459,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
@@ -916,7 +920,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,15 +5,15 @@ package dns
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
// WGIface defines subset methods of interface required for manager
|
||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
|
|||||||
@@ -1,15 +1,15 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIface defines subset methods of interface required for manager
|
// WGIface defines subset methods of interface required for manager
|
||||||
type WGIface interface {
|
type WGIface interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
GetFilter() device.PacketFilter
|
GetFilter() device.PacketFilter
|
||||||
GetDevice() *device.FilteredDevice
|
GetDevice() *device.FilteredDevice
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ func (h *Manager) allowDNSFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
@@ -33,6 +33,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
@@ -169,10 +172,11 @@ type Engine struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
firewall manager.Manager
|
firewall firewallManager.Manager
|
||||||
routeManager routemanager.Manager
|
routeManager routemanager.Manager
|
||||||
acl acl.Manager
|
acl acl.Manager
|
||||||
dnsForwardMgr *dnsfwd.Manager
|
dnsForwardMgr *dnsfwd.Manager
|
||||||
|
ingressGatewayMgr *ingressgw.Manager
|
||||||
|
|
||||||
dnsServer dns.Server
|
dnsServer dns.Server
|
||||||
|
|
||||||
@@ -187,6 +191,7 @@ type Engine struct {
|
|||||||
persistNetworkMap bool
|
persistNetworkMap bool
|
||||||
latestNetworkMap *mgmProto.NetworkMap
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
flowManager nftypes.FlowManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -266,6 +271,13 @@ func (e *Engine) Stop() error {
|
|||||||
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
// stop/restore DNS first so dbus and friends don't complain because of a missing interface
|
||||||
e.stopDNSServer()
|
e.stopDNSServer()
|
||||||
|
|
||||||
|
if e.ingressGatewayMgr != nil {
|
||||||
|
if err := e.ingressGatewayMgr.Close(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup forward rules: %v", err)
|
||||||
|
}
|
||||||
|
e.ingressGatewayMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
if e.routeManager != nil {
|
if e.routeManager != nil {
|
||||||
e.routeManager.Stop(e.stateManager)
|
e.routeManager.Stop(e.stateManager)
|
||||||
}
|
}
|
||||||
@@ -299,6 +311,12 @@ func (e *Engine) Stop() error {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
|
// stop flow manager after wg interface is gone
|
||||||
|
if e.flowManager != nil {
|
||||||
|
e.flowManager.Close()
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
@@ -333,6 +351,10 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
|
// start flow manager right after interface creation
|
||||||
|
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||||
|
e.flowManager = netflow.NewManager(e.ctx, e.wgInterface, publicKey[:], e.statusRecorder)
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
if e.config.RosenpassPermissive {
|
if e.config.RosenpassPermissive {
|
||||||
@@ -439,7 +461,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -469,16 +491,16 @@ func (e *Engine) initFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rosenpassPort := e.rpManager.GetAddress().Port
|
rosenpassPort := e.rpManager.GetAddress().Port
|
||||||
port := manager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
port := firewallManager.Port{Values: []uint16{uint16(rosenpassPort)}}
|
||||||
|
|
||||||
// this rule is static and will be torn down on engine down by the firewall manager
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
manager.ProtocolUDP,
|
firewallManager.ProtocolUDP,
|
||||||
nil,
|
nil,
|
||||||
&port,
|
&port,
|
||||||
manager.ActionAccept,
|
firewallManager.ActionAccept,
|
||||||
"",
|
|
||||||
"",
|
"",
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
@@ -503,12 +525,13 @@ func (e *Engine) blockLanAccess() {
|
|||||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
for _, network := range toBlock {
|
for _, network := range toBlock {
|
||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{v4},
|
[]netip.Prefix{v4},
|
||||||
network,
|
network,
|
||||||
manager.ProtocolALL,
|
firewallManager.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
manager.ActionDrop,
|
firewallManager.ActionDrop,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
merr = multierror.Append(merr, fmt.Errorf("add fw rule for network %s: %w", network, err))
|
||||||
}
|
}
|
||||||
@@ -633,25 +656,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
e.stunTurn.Store(stunTurn)
|
e.stunTurn.Store(stunTurn)
|
||||||
|
|
||||||
relayMsg := wCfg.GetRelay()
|
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||||
if relayMsg != nil {
|
if err != nil {
|
||||||
// when we receive token we expect valid address list too
|
return err
|
||||||
c := &auth.Token{
|
|
||||||
Payload: relayMsg.GetTokenPayload(),
|
|
||||||
Signature: relayMsg.GetTokenSignature(),
|
|
||||||
}
|
|
||||||
if err := e.relayManager.UpdateToken(c); err != nil {
|
|
||||||
log.Errorf("failed to update relay token: %v", err)
|
|
||||||
return fmt.Errorf("update relay token: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||||
|
if err != nil {
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
|
||||||
_ = e.relayManager.Serve()
|
|
||||||
} else {
|
|
||||||
e.relayManager.UpdateServerURLs(nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
@@ -682,6 +694,55 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||||
|
if update != nil {
|
||||||
|
// when we receive token we expect valid address list too
|
||||||
|
c := &auth.Token{
|
||||||
|
Payload: update.GetTokenPayload(),
|
||||||
|
Signature: update.GetTokenSignature(),
|
||||||
|
}
|
||||||
|
if err := e.relayManager.UpdateToken(c); err != nil {
|
||||||
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
|
|
||||||
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
|
_ = e.relayManager.Serve()
|
||||||
|
} else {
|
||||||
|
e.relayManager.UpdateServerURLs(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flowConfig, err := toFlowLoggerConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return e.flowManager.Update(flowConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||||
|
if config.GetInterval() == nil {
|
||||||
|
return nil, errors.New("flow interval is nil")
|
||||||
|
}
|
||||||
|
return &nftypes.FlowConfig{
|
||||||
|
Enabled: config.GetEnabled(),
|
||||||
|
Counters: config.GetCounters(),
|
||||||
|
URL: config.GetUrl(),
|
||||||
|
TokenPayload: config.GetTokenPayload(),
|
||||||
|
TokenSignature: config.GetTokenSignature(),
|
||||||
|
Interval: config.GetInterval().AsDuration(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||||
// if checks are equal, we skip the update
|
// if checks are equal, we skip the update
|
||||||
@@ -912,6 +973,11 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Ingress forward rules
|
||||||
|
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||||
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers()))
|
||||||
|
|
||||||
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
e.updateOfflinePeers(networkMap.GetOfflinePeers())
|
||||||
@@ -1482,7 +1548,7 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFirewallManager returns the firewall manager
|
// GetFirewallManager returns the firewall manager
|
||||||
func (e *Engine) GetFirewallManager() manager.Manager {
|
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1575,16 +1641,19 @@ func (e *Engine) probeTURNs() []relay.ProbeResult {
|
|||||||
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
return relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// restartEngine restarts the engine by cancelling the client context
|
||||||
func (e *Engine) restartEngine() {
|
func (e *Engine) restartEngine() {
|
||||||
log.Info("restarting engine")
|
e.syncMsgMux.Lock()
|
||||||
CtxGetState(e.ctx).Set(StatusConnecting)
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
if err := e.Stop(); err != nil {
|
if e.ctx.Err() != nil {
|
||||||
log.Errorf("Failed to stop engine: %v", err)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("restarting engine")
|
||||||
|
CtxGetState(e.ctx).Set(StatusConnecting)
|
||||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
log.Infof("cancelling client, engine will be recreated")
|
log.Infof("cancelling client context, engine will be recreated")
|
||||||
e.clientCancel()
|
e.clientCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1596,34 +1665,17 @@ func (e *Engine) startNetworkMonitor() {
|
|||||||
|
|
||||||
e.networkMonitor = networkmonitor.New()
|
e.networkMonitor = networkmonitor.New()
|
||||||
go func() {
|
go func() {
|
||||||
var mu sync.Mutex
|
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||||
var debounceTimer *time.Timer
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Infof("network monitor stopped")
|
||||||
// Start the network monitor with a callback, Start will block until the monitor is stopped,
|
return
|
||||||
// a network change is detected, or an error occurs on start up
|
}
|
||||||
err := e.networkMonitor.Start(e.ctx, func() {
|
log.Errorf("network monitor error: %v", err)
|
||||||
// This function is called when a network change is detected
|
return
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
if debounceTimer != nil {
|
|
||||||
log.Infof("Network monitor: detected network change, reset debounceTimer")
|
|
||||||
debounceTimer.Stop()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a new timer to debounce rapid network changes
|
|
||||||
debounceTimer = time.AfterFunc(2*time.Second, func() {
|
|
||||||
// This function is called after the debounce period
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
log.Infof("Network monitor: detected network change, restarting engine")
|
log.Infof("Network monitor: detected network change, restarting engine")
|
||||||
e.restartEngine()
|
e.restartEngine()
|
||||||
})
|
|
||||||
})
|
|
||||||
if err != nil && !errors.Is(err, networkmonitor.ErrStopped) {
|
|
||||||
log.Errorf("Network monitor: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1770,6 +1822,74 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return ip.Unmap(), nil
|
return ip.Unmap(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) error {
|
||||||
|
if e.firewall == nil {
|
||||||
|
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rules) == 0 {
|
||||||
|
if e.ingressGatewayMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
err := e.ingressGatewayMgr.Close()
|
||||||
|
e.ingressGatewayMgr = nil
|
||||||
|
e.statusRecorder.SetIngressGwMgr(nil)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.ingressGatewayMgr == nil {
|
||||||
|
mgr := ingressgw.NewManager(e.firewall)
|
||||||
|
e.ingressGatewayMgr = mgr
|
||||||
|
e.statusRecorder.SetIngressGwMgr(mgr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
forwardingRules := make([]firewallManager.ForwardRule, 0, len(rules))
|
||||||
|
for _, rule := range rules {
|
||||||
|
proto, err := convertToFirewallProtocol(rule.GetProtocol())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to convert protocol '%s': %w", rule.GetProtocol(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
dstPortInfo, err := convertPortInfo(rule.GetDestinationPort())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid destination port '%v': %w", rule.GetDestinationPort(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
translateIP, err := convertToIP(rule.GetTranslatedAddress())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to convert translated address '%s': %w", rule.GetTranslatedAddress(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
translatePort, err := convertPortInfo(rule.GetTranslatedPort())
|
||||||
|
if err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("invalid translate port '%v': %w", rule.GetTranslatedPort(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
forwardRule := firewallManager.ForwardRule{
|
||||||
|
Protocol: proto,
|
||||||
|
DestinationPort: *dstPortInfo,
|
||||||
|
TranslatedAddress: translateIP,
|
||||||
|
TranslatedPort: *translatePort,
|
||||||
|
}
|
||||||
|
|
||||||
|
forwardingRules = append(forwardingRules, forwardRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("updating forwarding rules: %d", len(forwardingRules))
|
||||||
|
if err := e.ingressGatewayMgr.Update(forwardingRules); err != nil {
|
||||||
|
log.Errorf("failed to update forwarding rules: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// isChecksEqual checks if two slices of checks are equal.
|
// isChecksEqual checks if two slices of checks are equal.
|
||||||
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
func isChecksEqual(checks []*mgmProto.Checks, oChecks []*mgmProto.Checks) bool {
|
||||||
for _, check := range checks {
|
for _, check := range checks {
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
@@ -44,6 +45,7 @@ import (
|
|||||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -74,7 +76,7 @@ type MockWGIface struct {
|
|||||||
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
CreateOnAndroidFunc func(routeRange []string, ip string, domains []string) error
|
||||||
IsUserspaceBindFunc func() bool
|
IsUserspaceBindFunc func() bool
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() device.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
ToInterfaceFunc func() *net.Interface
|
ToInterfaceFunc func() *net.Interface
|
||||||
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
UpFunc func() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddrFunc func(newAddr string) error
|
UpdateAddrFunc func(newAddr string) error
|
||||||
@@ -113,7 +115,7 @@ func (m *MockWGIface) Name() string {
|
|||||||
return m.NameFunc()
|
return m.NameFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockWGIface) Address() device.WGAddress {
|
func (m *MockWGIface) Address() wgaddr.Address {
|
||||||
return m.AddressFunc()
|
return m.AddressFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -363,8 +365,8 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
RemovePeerFunc: func(peerKey string) error {
|
RemovePeerFunc: func(peerKey string) error {
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -1433,13 +1435,13 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics)
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,7 +21,7 @@ type wgIfaceBase interface {
|
|||||||
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
CreateOnAndroid(routeRange []string, ip string, domains []string) error
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
Name() string
|
Name() string
|
||||||
Address() device.WGAddress
|
Address() wgaddr.Address
|
||||||
ToInterface() *net.Interface
|
ToInterface() *net.Interface
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(newAddr string) error
|
UpdateAddr(newAddr string) error
|
||||||
|
|||||||
107
client/internal/ingressgw/manager.go
Normal file
107
client/internal/ingressgw/manager.go
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
package ingressgw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DNATFirewall interface {
|
||||||
|
AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error)
|
||||||
|
DeleteDNATRule(rule firewall.Rule) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type RulePair struct {
|
||||||
|
firewall.ForwardRule
|
||||||
|
firewall.Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
dnatFirewall DNATFirewall
|
||||||
|
|
||||||
|
rules map[string]RulePair // keys is the ID of the ForwardRule
|
||||||
|
rulesMu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(dnatFirewall DNATFirewall) *Manager {
|
||||||
|
return &Manager{
|
||||||
|
dnatFirewall: dnatFirewall,
|
||||||
|
rules: make(map[string]RulePair),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Update(forwardRules []firewall.ForwardRule) error {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
var mErr *multierror.Error
|
||||||
|
|
||||||
|
toDelete := make(map[string]RulePair, len(h.rules))
|
||||||
|
for id, r := range h.rules {
|
||||||
|
toDelete[id] = r
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process new/updated rules
|
||||||
|
for _, fwdRule := range forwardRules {
|
||||||
|
id := fwdRule.ID()
|
||||||
|
if _, ok := h.rules[id]; ok {
|
||||||
|
delete(toDelete, id)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := h.dnatFirewall.AddDNATRule(fwdRule)
|
||||||
|
if err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("add forward rule '%s': %v", fwdRule.String(), err))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Infof("forward rule has been added '%s'", fwdRule)
|
||||||
|
h.rules[id] = RulePair{
|
||||||
|
ForwardRule: fwdRule,
|
||||||
|
Rule: rule,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove deleted rules
|
||||||
|
for id, rulePair := range toDelete {
|
||||||
|
if err := h.dnatFirewall.DeleteDNATRule(rulePair.Rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rulePair.ForwardRule.String(), err))
|
||||||
|
}
|
||||||
|
log.Infof("forward rule has been deleted '%s'", rulePair.ForwardRule)
|
||||||
|
delete(h.rules, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Close() error {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
log.Infof("clean up all (%d) forward rules", len(h.rules))
|
||||||
|
var mErr *multierror.Error
|
||||||
|
for _, rule := range h.rules {
|
||||||
|
if err := h.dnatFirewall.DeleteDNATRule(rule.Rule); err != nil {
|
||||||
|
mErr = multierror.Append(mErr, fmt.Errorf("failed to delete forward rule '%s': %v", rule, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
h.rules = make(map[string]RulePair)
|
||||||
|
return nberrors.FormatErrorOrNil(mErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Manager) Rules() []firewall.ForwardRule {
|
||||||
|
h.rulesMu.Lock()
|
||||||
|
defer h.rulesMu.Unlock()
|
||||||
|
|
||||||
|
rules := make([]firewall.ForwardRule, 0, len(h.rules))
|
||||||
|
for _, rulePair := range h.rules {
|
||||||
|
rules = append(rules, rulePair.ForwardRule)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
281
client/internal/ingressgw/manager_test.go
Normal file
281
client/internal/ingressgw/manager_test.go
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
package ingressgw
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ firewall.Rule = (*MocFwRule)(nil)
|
||||||
|
_ DNATFirewall = &MockDNATFirewall{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type MocFwRule struct {
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MocFwRule) ID() string {
|
||||||
|
return string(m.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
type MockDNATFirewall struct {
|
||||||
|
throwError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) AddDNATRule(fwdRule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.throwError {
|
||||||
|
return nil, fmt.Errorf("moc error")
|
||||||
|
}
|
||||||
|
|
||||||
|
fwRule := &MocFwRule{
|
||||||
|
id: fwdRule.ID(),
|
||||||
|
}
|
||||||
|
return fwRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.throwError {
|
||||||
|
return fmt.Errorf("moc error")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockDNATFirewall) forceToThrowErrors() {
|
||||||
|
m.throwError = true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_AddRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
|
||||||
|
updates := []firewall.ForwardRule{
|
||||||
|
{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}}
|
||||||
|
|
||||||
|
if err := mgr.Update(updates); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != len(updates) {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_UpdateRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleUDP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 1 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].TranslatedAddress.String() != ruleUDP.TranslatedAddress.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].TranslatedPort.String() != ruleUDP.TranslatedPort.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].DestinationPort.String() != ruleUDP.DestinationPort.String() {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
if rules[0].Protocol != ruleUDP.Protocol {
|
||||||
|
t.Errorf("unexpected rule: %v", rules[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_ExtendRules(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 2 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_UnderlingError(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleUDP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolUDP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.2"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fw.forceToThrowErrors()
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP, ruleUDP}); err == nil {
|
||||||
|
t.Errorf("expected error")
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 1 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Cleanup(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_DeleteBrokenRule(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
|
||||||
|
// force to throw errors when Add DNAT Rule
|
||||||
|
fw.forceToThrowErrors()
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err == nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// simulate that to remove a broken rule
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Close(); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Close(t *testing.T) {
|
||||||
|
fw := &MockDNATFirewall{}
|
||||||
|
mgr := NewManager(fw)
|
||||||
|
|
||||||
|
port, _ := firewall.NewPort(8080)
|
||||||
|
ruleTCP := firewall.ForwardRule{
|
||||||
|
Protocol: firewall.ProtocolTCP,
|
||||||
|
DestinationPort: *port,
|
||||||
|
TranslatedAddress: netip.MustParseAddr("172.16.254.1"),
|
||||||
|
TranslatedPort: *port,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Update([]firewall.ForwardRule{ruleTCP}); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := mgr.Close(); err != nil {
|
||||||
|
t.Errorf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules := mgr.Rules()
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Errorf("unexpected rules count: %d", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
58
client/internal/message_convert.go
Normal file
58
client/internal/message_convert.go
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func convertToFirewallProtocol(protocol mgmProto.RuleProtocol) (firewallManager.Protocol, error) {
|
||||||
|
switch protocol {
|
||||||
|
case mgmProto.RuleProtocol_TCP:
|
||||||
|
return firewallManager.ProtocolTCP, nil
|
||||||
|
case mgmProto.RuleProtocol_UDP:
|
||||||
|
return firewallManager.ProtocolUDP, nil
|
||||||
|
case mgmProto.RuleProtocol_ICMP:
|
||||||
|
return firewallManager.ProtocolICMP, nil
|
||||||
|
case mgmProto.RuleProtocol_ALL:
|
||||||
|
return firewallManager.ProtocolALL, nil
|
||||||
|
default:
|
||||||
|
return "", fmt.Errorf("invalid protocol type: %s", protocol.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPortInfo(portInfo *mgmProto.PortInfo) (*firewallManager.Port, error) {
|
||||||
|
if portInfo == nil {
|
||||||
|
return nil, errors.New("portInfo cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetPort() != 0 {
|
||||||
|
return firewallManager.NewPort(int(portInfo.GetPort()))
|
||||||
|
}
|
||||||
|
|
||||||
|
if portInfo.GetRange() != nil {
|
||||||
|
return firewallManager.NewPort(int(portInfo.GetRange().Start), int(portInfo.GetRange().End))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("invalid portInfo: %v", portInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertToIP(rawIP []byte) (netip.Addr, error) {
|
||||||
|
if rawIP == nil {
|
||||||
|
return netip.Addr{}, errors.New("input bytes cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawIP) != net.IPv4len && len(rawIP) != net.IPv6len {
|
||||||
|
return netip.Addr{}, fmt.Errorf("invalid IP length: %d", len(rawIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(rawIP) == net.IPv4len {
|
||||||
|
return netip.AddrFrom4([4]byte(rawIP)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return netip.AddrFrom16([16]byte(rawIP)), nil
|
||||||
|
}
|
||||||
306
client/internal/netflow/conntrack/conntrack.go
Normal file
306
client/internal/netflow/conntrack/conntrack.go
Normal file
@@ -0,0 +1,306 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
nfct "github.com/ti-mo/conntrack"
|
||||||
|
"github.com/ti-mo/netfilter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultChannelSize = 100
|
||||||
|
|
||||||
|
// ConnTrack manages kernel-based conntrack events
|
||||||
|
type ConnTrack struct {
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
iface nftypes.IFaceMapper
|
||||||
|
|
||||||
|
conn *nfct.Conn
|
||||||
|
mux sync.Mutex
|
||||||
|
|
||||||
|
instanceID uuid.UUID
|
||||||
|
started bool
|
||||||
|
done chan struct{}
|
||||||
|
sysctlModified bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new connection tracker that interfaces with the kernel's conntrack system
|
||||||
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack {
|
||||||
|
return &ConnTrack{
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
iface: iface,
|
||||||
|
instanceID: uuid.New(),
|
||||||
|
started: false,
|
||||||
|
done: make(chan struct{}, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins tracking connections by listening for conntrack events. This method is idempotent.
|
||||||
|
func (c *ConnTrack) Start(enableCounters bool) error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.started {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Starting conntrack event listening")
|
||||||
|
|
||||||
|
if enableCounters {
|
||||||
|
c.EnableAccounting()
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := nfct.Dial(nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("dial conntrack: %w", err)
|
||||||
|
}
|
||||||
|
c.conn = conn
|
||||||
|
|
||||||
|
events := make(chan nfct.Event, defaultChannelSize)
|
||||||
|
errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{
|
||||||
|
netfilter.GroupCTNew,
|
||||||
|
netfilter.GroupCTDestroy,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
c.conn = nil
|
||||||
|
return fmt.Errorf("start conntrack listener: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.started = true
|
||||||
|
|
||||||
|
go c.receiverRoutine(events, errChan)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case event := <-events:
|
||||||
|
c.handleEvent(event)
|
||||||
|
case err := <-errChan:
|
||||||
|
log.Errorf("Error from conntrack event listener: %v", err)
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-c.done:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the connection tracking. This method is idempotent.
|
||||||
|
func (c *ConnTrack) Stop() {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if !c.started {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Stopping conntrack event listening")
|
||||||
|
|
||||||
|
select {
|
||||||
|
case c.done <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
log.Errorf("Error closing conntrack connection: %v", err)
|
||||||
|
}
|
||||||
|
c.conn = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.started = false
|
||||||
|
|
||||||
|
c.RestoreAccounting()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops listening for events and cleans up resources
|
||||||
|
func (c *ConnTrack) Close() error {
|
||||||
|
c.mux.Lock()
|
||||||
|
defer c.mux.Unlock()
|
||||||
|
|
||||||
|
if c.started {
|
||||||
|
select {
|
||||||
|
case c.done <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
err := c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
c.started = false
|
||||||
|
|
||||||
|
c.RestoreAccounting()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("close conntrack: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleEvent processes incoming conntrack events
|
||||||
|
func (c *ConnTrack) handleEvent(event nfct.Event) {
|
||||||
|
if event.Flow == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Type != nfct.EventNew && event.Type != nfct.EventDestroy {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flow := *event.Flow
|
||||||
|
|
||||||
|
proto := nftypes.Protocol(flow.TupleOrig.Proto.Protocol)
|
||||||
|
if proto == nftypes.ProtocolUnknown {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srcIP := flow.TupleOrig.IP.SourceAddress
|
||||||
|
dstIP := flow.TupleOrig.IP.DestinationAddress
|
||||||
|
|
||||||
|
if !c.relevantFlow(srcIP, dstIP) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var srcPort, dstPort uint16
|
||||||
|
var icmpType, icmpCode uint8
|
||||||
|
|
||||||
|
switch proto {
|
||||||
|
case nftypes.TCP, nftypes.UDP, nftypes.SCTP:
|
||||||
|
srcPort = flow.TupleOrig.Proto.SourcePort
|
||||||
|
dstPort = flow.TupleOrig.Proto.DestinationPort
|
||||||
|
case nftypes.ICMP:
|
||||||
|
icmpType = flow.TupleOrig.Proto.ICMPType
|
||||||
|
icmpCode = flow.TupleOrig.Proto.ICMPCode
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := c.getFlowID(flow.ID)
|
||||||
|
direction := c.inferDirection(srcIP, dstIP)
|
||||||
|
|
||||||
|
eventType := nftypes.TypeStart
|
||||||
|
eventStr := "New"
|
||||||
|
|
||||||
|
if event.Type == nfct.EventDestroy {
|
||||||
|
eventType = nftypes.TypeEnd
|
||||||
|
eventStr = "Ended"
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("%s %s %s connection: %s:%d -> %s:%d", eventStr, direction, proto, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
c.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: eventType,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: proto,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
RxPackets: c.mapRxPackets(flow, direction),
|
||||||
|
TxPackets: c.mapTxPackets(flow, direction),
|
||||||
|
RxBytes: c.mapRxBytes(flow, direction),
|
||||||
|
TxBytes: c.mapTxBytes(flow, direction),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// relevantFlow checks if the flow is related to the specified interface
|
||||||
|
func (c *ConnTrack) relevantFlow(srcIP, dstIP netip.Addr) bool {
|
||||||
|
// TODO: filter traffic by interface
|
||||||
|
|
||||||
|
wgnet := c.iface.Address().Network
|
||||||
|
if !wgnet.Contains(srcIP.AsSlice()) && !wgnet.Contains(dstIP.AsSlice()) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapRxPackets maps packet counts to RX based on flow direction
|
||||||
|
func (c *ConnTrack) mapRxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersOrig is from external to us (RX)
|
||||||
|
// For Egress: CountersReply is from external to us (RX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersOrig.Packets
|
||||||
|
}
|
||||||
|
return flow.CountersReply.Packets
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapTxPackets maps packet counts to TX based on flow direction
|
||||||
|
func (c *ConnTrack) mapTxPackets(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersReply is from us to external (TX)
|
||||||
|
// For Egress: CountersOrig is from us to external (TX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersReply.Packets
|
||||||
|
}
|
||||||
|
return flow.CountersOrig.Packets
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapRxBytes maps byte counts to RX based on flow direction
|
||||||
|
func (c *ConnTrack) mapRxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersOrig is from external to us (RX)
|
||||||
|
// For Egress: CountersReply is from external to us (RX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersOrig.Bytes
|
||||||
|
}
|
||||||
|
return flow.CountersReply.Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// mapTxBytes maps byte counts to TX based on flow direction
|
||||||
|
func (c *ConnTrack) mapTxBytes(flow nfct.Flow, direction nftypes.Direction) uint64 {
|
||||||
|
// For Ingress: CountersReply is from us to external (TX)
|
||||||
|
// For Egress: CountersOrig is from us to external (TX)
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
return flow.CountersReply.Bytes
|
||||||
|
}
|
||||||
|
return flow.CountersOrig.Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// getFlowID creates a unique UUID based on the conntrack ID and instance ID
|
||||||
|
func (c *ConnTrack) getFlowID(conntrackID uint32) uuid.UUID {
|
||||||
|
var buf [4]byte
|
||||||
|
binary.BigEndian.PutUint32(buf[:], conntrackID)
|
||||||
|
return uuid.NewSHA1(c.instanceID, buf[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnTrack) inferDirection(srcIP, dstIP netip.Addr) nftypes.Direction {
|
||||||
|
wgaddr := c.iface.Address().IP
|
||||||
|
wgnetwork := c.iface.Address().Network
|
||||||
|
src, dst := srcIP.AsSlice(), dstIP.AsSlice()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case wgaddr.Equal(src):
|
||||||
|
return nftypes.Egress
|
||||||
|
case wgaddr.Equal(dst):
|
||||||
|
return nftypes.Ingress
|
||||||
|
case wgnetwork.Contains(src):
|
||||||
|
// netbird network -> resource network
|
||||||
|
return nftypes.Ingress
|
||||||
|
case wgnetwork.Contains(dst):
|
||||||
|
// resource network -> netbird network
|
||||||
|
return nftypes.Egress
|
||||||
|
|
||||||
|
// TODO: handle site2site traffic
|
||||||
|
}
|
||||||
|
|
||||||
|
return nftypes.DirectionUnknown
|
||||||
|
}
|
||||||
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
9
client/internal/netflow/conntrack/conntrack_nonlinux.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
|
||||||
|
func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) nftypes.ConnTracker {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
73
client/internal/netflow/conntrack/sysctl.go
Normal file
73
client/internal/netflow/conntrack/sysctl.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build linux && !android
|
||||||
|
|
||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// conntrackAcctPath is the sysctl path for conntrack accounting
|
||||||
|
conntrackAcctPath = "net.netfilter.nf_conntrack_acct"
|
||||||
|
)
|
||||||
|
|
||||||
|
// EnableAccounting ensures that connection tracking accounting is enabled in the kernel.
|
||||||
|
func (c *ConnTrack) EnableAccounting() {
|
||||||
|
// haven't restored yet
|
||||||
|
if c.sysctlModified {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modified, err := setSysctl(conntrackAcctPath, 1)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to enable conntrack accounting: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.sysctlModified = modified
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreAccounting restores the connection tracking accounting setting to its original value.
|
||||||
|
func (c *ConnTrack) RestoreAccounting() {
|
||||||
|
if !c.sysctlModified {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := setSysctl(conntrackAcctPath, 0); err != nil {
|
||||||
|
log.Warnf("Failed to restore conntrack accounting: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sysctlModified = false
|
||||||
|
}
|
||||||
|
|
||||||
|
// setSysctl sets a sysctl configuration and returns whether it was modified.
|
||||||
|
func setSysctl(key string, desiredValue int) (bool, error) {
|
||||||
|
path := fmt.Sprintf("/proc/sys/%s", strings.ReplaceAll(key, ".", "/"))
|
||||||
|
|
||||||
|
currentValue, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("read sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
currentV, err := strconv.Atoi(strings.TrimSpace(string(currentValue)))
|
||||||
|
if err != nil && len(currentValue) > 0 {
|
||||||
|
return false, fmt.Errorf("convert current value to int: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if currentV == desiredValue {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:gosec
|
||||||
|
if err := os.WriteFile(path, []byte(strconv.Itoa(desiredValue)), 0644); err != nil {
|
||||||
|
return false, fmt.Errorf("write sysctl %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Set sysctl %s from %d to %d", key, currentV, desiredValue)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
125
client/internal/netflow/logger/logger.go
Normal file
125
client/internal/netflow/logger/logger.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/store"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rcvChan chan *types.EventFields
|
||||||
|
type Logger struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
enabled atomic.Bool
|
||||||
|
rcvChan atomic.Pointer[rcvChan]
|
||||||
|
cancelReceiver context.CancelFunc
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
Store types.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(ctx context.Context, statusRecorder *peer.Status) *Logger {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
return &Logger{
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
Store: store.NewMemoryStore(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) StoreEvent(flowEvent types.EventFields) {
|
||||||
|
if !l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c := l.rcvChan.Load()
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case *c <- &flowEvent:
|
||||||
|
default:
|
||||||
|
// todo: we should collect or log on this
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Enable() {
|
||||||
|
go l.startReceiver()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) startReceiver() {
|
||||||
|
if l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mux.Lock()
|
||||||
|
ctx, cancel := context.WithCancel(l.ctx)
|
||||||
|
l.cancelReceiver = cancel
|
||||||
|
l.mux.Unlock()
|
||||||
|
|
||||||
|
c := make(rcvChan, 100)
|
||||||
|
l.rcvChan.Store(&c)
|
||||||
|
l.enabled.Store(true)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Info("flow Memory store receiver stopped")
|
||||||
|
return
|
||||||
|
case eventFields := <-c:
|
||||||
|
id := uuid.New()
|
||||||
|
event := types.Event{
|
||||||
|
ID: id,
|
||||||
|
EventFields: *eventFields,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
}
|
||||||
|
srcResId, dstResId := l.statusRecorder.CheckRoutes(event.SourceIP, event.DestIP, event.Direction)
|
||||||
|
event.SourceResourceID = []byte(srcResId)
|
||||||
|
event.DestResourceID = []byte(dstResId)
|
||||||
|
l.Store.StoreEvent(&event)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Disable() {
|
||||||
|
l.stop()
|
||||||
|
l.Store.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) stop() {
|
||||||
|
if !l.enabled.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
l.enabled.Store(false)
|
||||||
|
l.mux.Lock()
|
||||||
|
if l.cancelReceiver != nil {
|
||||||
|
l.cancelReceiver()
|
||||||
|
l.cancelReceiver = nil
|
||||||
|
}
|
||||||
|
l.rcvChan.Store(nil)
|
||||||
|
l.mux.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) GetEvents() []*types.Event {
|
||||||
|
return l.Store.GetEvents()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) DeleteEvents(ids []uuid.UUID) {
|
||||||
|
l.Store.DeleteEvents(ids)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) Close() {
|
||||||
|
l.stop()
|
||||||
|
l.cancel()
|
||||||
|
}
|
||||||
67
client/internal/netflow/logger/logger_test.go
Normal file
67
client/internal/netflow/logger/logger_test.go
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
package logger_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestStore(t *testing.T) {
|
||||||
|
logger := logger.New(context.Background(), nil)
|
||||||
|
logger.Enable()
|
||||||
|
|
||||||
|
event := types.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: types.TypeStart,
|
||||||
|
Direction: types.Ingress,
|
||||||
|
Protocol: 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
wait := func() { time.Sleep(time.Millisecond) }
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
|
||||||
|
allEvents := logger.GetEvents()
|
||||||
|
matched := false
|
||||||
|
for _, e := range allEvents {
|
||||||
|
if e.EventFields.FlowID == event.FlowID {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Errorf("didn't match any event")
|
||||||
|
}
|
||||||
|
|
||||||
|
// test disable
|
||||||
|
logger.Disable()
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
allEvents = logger.GetEvents()
|
||||||
|
if len(allEvents) != 0 {
|
||||||
|
t.Errorf("expected 0 events, got %d", len(allEvents))
|
||||||
|
}
|
||||||
|
|
||||||
|
// test re-enable
|
||||||
|
logger.Enable()
|
||||||
|
wait()
|
||||||
|
logger.StoreEvent(event)
|
||||||
|
wait()
|
||||||
|
|
||||||
|
allEvents = logger.GetEvents()
|
||||||
|
matched = false
|
||||||
|
for _, e := range allEvents {
|
||||||
|
if e.EventFields.FlowID == event.FlowID {
|
||||||
|
matched = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !matched {
|
||||||
|
t.Errorf("didn't match any event")
|
||||||
|
}
|
||||||
|
}
|
||||||
235
client/internal/netflow/manager.go
Normal file
235
client/internal/netflow/manager.go
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
package netflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow/logger"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/flow/client"
|
||||||
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager handles netflow tracking and logging
|
||||||
|
type Manager struct {
|
||||||
|
mux sync.Mutex
|
||||||
|
logger nftypes.FlowLogger
|
||||||
|
flowConfig *nftypes.FlowConfig
|
||||||
|
conntrack nftypes.ConnTracker
|
||||||
|
ctx context.Context
|
||||||
|
receiverClient *client.GRPCClient
|
||||||
|
publicKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new netflow manager
|
||||||
|
func NewManager(ctx context.Context, iface nftypes.IFaceMapper, publicKey []byte, statusRecorder *peer.Status) *Manager {
|
||||||
|
flowLogger := logger.New(ctx, statusRecorder)
|
||||||
|
|
||||||
|
var ct nftypes.ConnTracker
|
||||||
|
if runtime.GOOS == "linux" && iface != nil && !iface.IsUserspaceBind() {
|
||||||
|
ct = conntrack.New(flowLogger, iface)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{
|
||||||
|
logger: flowLogger,
|
||||||
|
conntrack: ct,
|
||||||
|
ctx: ctx,
|
||||||
|
publicKey: publicKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update applies new flow configuration settings
|
||||||
|
// needsNewClient checks if a new client needs to be created
|
||||||
|
func (m *Manager) needsNewClient(previous *nftypes.FlowConfig) bool {
|
||||||
|
current := m.flowConfig
|
||||||
|
return previous == nil ||
|
||||||
|
!previous.Enabled ||
|
||||||
|
previous.TokenPayload != current.TokenPayload ||
|
||||||
|
previous.TokenSignature != current.TokenSignature ||
|
||||||
|
previous.URL != current.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// enableFlow starts components for flow tracking
|
||||||
|
func (m *Manager) enableFlow(previous *nftypes.FlowConfig) error {
|
||||||
|
// first make sender ready so events don't pile up
|
||||||
|
if m.needsNewClient(previous) {
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
if err := m.receiverClient.Close(); err != nil {
|
||||||
|
log.Warnf("error closing previous flow client: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flowClient, err := client.NewClient(m.flowConfig.URL, m.flowConfig.TokenPayload, m.flowConfig.TokenSignature, m.flowConfig.Interval)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create client: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("flow client configured to connect to %s", m.flowConfig.URL)
|
||||||
|
|
||||||
|
m.receiverClient = flowClient
|
||||||
|
go m.receiveACKs(flowClient)
|
||||||
|
go m.startSender()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Enable()
|
||||||
|
|
||||||
|
if m.conntrack != nil {
|
||||||
|
if err := m.conntrack.Start(m.flowConfig.Counters); err != nil {
|
||||||
|
return fmt.Errorf("start conntrack: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// disableFlow stops components for flow tracking
|
||||||
|
func (m *Manager) disableFlow() error {
|
||||||
|
if m.conntrack != nil {
|
||||||
|
m.conntrack.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Disable()
|
||||||
|
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
return m.receiverClient.Close()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update applies new flow configuration settings
|
||||||
|
func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||||
|
if update == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
previous := m.flowConfig
|
||||||
|
m.flowConfig = update
|
||||||
|
|
||||||
|
if update.Enabled {
|
||||||
|
return m.enableFlow(previous)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.disableFlow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close cleans up all resources
|
||||||
|
func (m *Manager) Close() {
|
||||||
|
m.mux.Lock()
|
||||||
|
defer m.mux.Unlock()
|
||||||
|
|
||||||
|
if m.conntrack != nil {
|
||||||
|
m.conntrack.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.receiverClient != nil {
|
||||||
|
if err := m.receiverClient.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close receiver client: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLogger returns the flow logger
|
||||||
|
func (m *Manager) GetLogger() nftypes.FlowLogger {
|
||||||
|
return m.logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) startSender() {
|
||||||
|
ticker := time.NewTicker(m.flowConfig.Interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-m.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
events := m.logger.GetEvents()
|
||||||
|
for _, event := range events {
|
||||||
|
if err := m.send(event); err != nil {
|
||||||
|
log.Errorf("failed to send flow event to server: %s", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Tracef("sent flow event: %s", event.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) receiveACKs(client *client.GRPCClient) {
|
||||||
|
err := client.Receive(m.ctx, m.flowConfig.Interval, func(ack *proto.FlowEventAck) error {
|
||||||
|
log.Tracef("received flow event ack: %s", ack.EventId)
|
||||||
|
m.logger.DeleteEvents([]uuid.UUID{uuid.UUID(ack.EventId)})
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, context.Canceled) {
|
||||||
|
log.Errorf("failed to receive flow event ack: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) send(event *nftypes.Event) error {
|
||||||
|
m.mux.Lock()
|
||||||
|
client := m.receiverClient
|
||||||
|
m.mux.Unlock()
|
||||||
|
|
||||||
|
if client == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return client.Send(toProtoEvent(m.publicKey, event))
|
||||||
|
}
|
||||||
|
|
||||||
|
func toProtoEvent(publicKey []byte, event *nftypes.Event) *proto.FlowEvent {
|
||||||
|
protoEvent := &proto.FlowEvent{
|
||||||
|
EventId: event.ID[:],
|
||||||
|
Timestamp: timestamppb.New(event.Timestamp),
|
||||||
|
PublicKey: publicKey,
|
||||||
|
FlowFields: &proto.FlowFields{
|
||||||
|
FlowId: event.FlowID[:],
|
||||||
|
RuleId: event.RuleID,
|
||||||
|
Type: proto.Type(event.Type),
|
||||||
|
Direction: proto.Direction(event.Direction),
|
||||||
|
Protocol: uint32(event.Protocol),
|
||||||
|
SourceIp: event.SourceIP.AsSlice(),
|
||||||
|
DestIp: event.DestIP.AsSlice(),
|
||||||
|
RxPackets: event.RxPackets,
|
||||||
|
TxPackets: event.TxPackets,
|
||||||
|
RxBytes: event.RxBytes,
|
||||||
|
TxBytes: event.TxBytes,
|
||||||
|
SourceResourceId: event.SourceResourceID,
|
||||||
|
DestResourceId: event.DestResourceID,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if event.Protocol == nftypes.ICMP {
|
||||||
|
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_IcmpInfo{
|
||||||
|
IcmpInfo: &proto.ICMPInfo{
|
||||||
|
IcmpType: uint32(event.ICMPType),
|
||||||
|
IcmpCode: uint32(event.ICMPCode),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return protoEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
protoEvent.FlowFields.ConnectionInfo = &proto.FlowFields_PortInfo{
|
||||||
|
PortInfo: &proto.PortInfo{
|
||||||
|
SourcePort: uint32(event.SourcePort),
|
||||||
|
DestPort: uint32(event.DestPort),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return protoEvent
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user