mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Clamp MSS on outbound traffic (#4735)
This commit is contained in:
@@ -15,13 +15,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||||
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface)
|
fm, err := createFW(iface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
}
|
}
|
||||||
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
|
|||||||
return fm, nil
|
return fm, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||||
switch check() {
|
switch check() {
|
||||||
case IPTABLES:
|
case IPTABLES:
|
||||||
log.Info("creating an iptables firewall manager")
|
log.Info("creating an iptables firewall manager")
|
||||||
return nbiptables.Create(iface)
|
return nbiptables.Create(iface, mtu)
|
||||||
case NFTABLES:
|
case NFTABLES:
|
||||||
log.Info("creating an nftables firewall manager")
|
log.Info("creating an nftables firewall manager")
|
||||||
return nbnftables.Create(iface)
|
return nbnftables.Create(iface, mtu)
|
||||||
default:
|
default:
|
||||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||||
return nil, errors.New("no firewall manager found")
|
return nil, errors.New("no firewall manager found")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ type iFaceMapper interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create iptables firewall manager
|
// Create iptables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("init iptables: %w", err)
|
return nil, fmt.Errorf("init iptables: %w", err)
|
||||||
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
ipv4Client: iptablesClient,
|
ipv4Client: iptablesClient,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.router, err = newRouter(iptablesClient, wgIface)
|
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
stateManager.RegisterState(state)
|
stateManager.RegisterState(state)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
|||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|||||||
@@ -30,17 +30,20 @@ const (
|
|||||||
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
chainPREROUTING = "PREROUTING"
|
chainPREROUTING = "PREROUTING"
|
||||||
|
chainFORWARD = "FORWARD"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
chainRTRDR = "NETBIRD-RT-RDR"
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
|
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
jumpMSSClamp = "jump-mss-clamp"
|
||||||
markManglePre = "mark-mangle-pre"
|
markManglePre = "mark-mangle-pre"
|
||||||
markManglePost = "mark-mangle-post"
|
markManglePost = "mark-mangle-post"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
@@ -48,6 +51,9 @@ const (
|
|||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
fwdSuffix = "_fwd"
|
fwdSuffix = "_fwd"
|
||||||
|
|
||||||
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
|
ipTCPHeaderMinSize = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
type ruleInfo struct {
|
type ruleInfo struct {
|
||||||
@@ -77,16 +83,18 @@ type router struct {
|
|||||||
ipsetCounter *ipsetCounter
|
ipsetCounter *ipsetCounter
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
mtu uint16
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||||
r := &router{
|
r := &router{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
mtu: mtu,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
|
|||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
{chainRTRDR, tableNat},
|
{chainRTRDR, tableNat},
|
||||||
|
{chainRTMSSCLAMP, tableMangle},
|
||||||
} {
|
} {
|
||||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
|
|||||||
return fmt.Errorf("add jump rules: %w", err)
|
return fmt.Errorf("add jump rules: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
|
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
// TODO: Add IPv6 support
|
||||||
|
func (r *router) addMSSClampingRules() error {
|
||||||
|
mss := r.mtu - ipTCPHeaderMinSize
|
||||||
|
|
||||||
|
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||||
|
jumpRule := []string{
|
||||||
|
"-j", chainRTMSSCLAMP,
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
|
||||||
|
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpMSSClamp] = jumpRule
|
||||||
|
|
||||||
|
ruleOut := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", "tcp",
|
||||||
|
"--tcp-flags", "SYN,RST", "SYN",
|
||||||
|
"-j", "TCPMSS",
|
||||||
|
"--set-mss", fmt.Sprintf("%d", mss),
|
||||||
|
}
|
||||||
|
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
|
||||||
|
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||||
|
}
|
||||||
|
r.rules["mss-clamp-out"] = ruleOut
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
establishedRule := getConntrackEstablished()
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
var table, chain string
|
var table, chain string
|
||||||
switch ruleKey {
|
switch ruleKey {
|
||||||
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
|
|||||||
case jumpNatPre:
|
case jumpNatPre:
|
||||||
table = tableNat
|
table = tableNat
|
||||||
chain = chainPREROUTING
|
chain = chainPREROUTING
|
||||||
|
case jumpMSSClamp:
|
||||||
|
table = tableMangle
|
||||||
|
chain = chainFORWARD
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "should return a valid iptables manager")
|
require.NoError(t, err, "should return a valid iptables manager")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Now 5 rules:
|
|
||||||
// 1. established rule forward in
|
// 1. established rule forward in
|
||||||
// 2. estbalished rule forward out
|
// 2. estbalished rule forward out
|
||||||
// 3. jump rule to POST nat chain
|
// 3. jump rule to POST nat chain
|
||||||
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 7. static return masquerade rule
|
// 7. static return masquerade rule
|
||||||
// 8. mangle prerouting mark rule
|
// 8. mangle prerouting mark rule
|
||||||
// 9. mangle postrouting mark rule
|
// 9. mangle postrouting mark rule
|
||||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
// 10. jump rule to MSS clamping chain
|
||||||
|
// 11. MSS clamping rule for outbound traffic
|
||||||
|
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "failed to init iptables client")
|
require.NoError(t, err, "failed to init iptables client")
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
|
|
||||||
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
|
|
||||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "shouldn't return error")
|
require.NoError(t, err, "shouldn't return error")
|
||||||
require.NoError(t, manager.init(nil))
|
require.NoError(t, manager.init(nil))
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||||
require.NoError(t, err, "Failed to create iptables client")
|
require.NoError(t, err, "Failed to create iptables client")
|
||||||
|
|
||||||
r, err := newRouter(iptablesClient, ifaceMock)
|
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router manager")
|
require.NoError(t, err, "Failed to create router manager")
|
||||||
require.NoError(t, r.init(nil))
|
require.NoError(t, r.init(nil))
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -11,6 +12,7 @@ type InterfaceState struct {
|
|||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
ipt, err := Create(s.InterfaceState)
|
mtu := s.InterfaceState.MTU
|
||||||
|
if mtu == 0 {
|
||||||
|
mtu = iface.DefaultMTU
|
||||||
|
}
|
||||||
|
ipt, err := Create(s.InterfaceState, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create iptables manager: %w", err)
|
return fmt.Errorf("create iptables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ type Manager struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create nftables firewall manager
|
// Create nftables firewall manager
|
||||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
rConn: &nftables.Conn{},
|
rConn: &nftables.Conn{},
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
@@ -53,7 +53,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
m.router, err = newRouter(workTable, wgIface)
|
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
@@ -93,6 +93,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
WGAddress: m.wgIface.Address(),
|
WGAddress: m.wgIface.Address(),
|
||||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||||
|
MTU: m.router.mtu,
|
||||||
},
|
},
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
log.Errorf("failed to update state: %v", err)
|
log.Errorf("failed to update state: %v", err)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
|||||||
func TestNftablesManager(t *testing.T) {
|
func TestNftablesManager(t *testing.T) {
|
||||||
|
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
func TestNftablesManagerRuleOrder(t *testing.T) {
|
func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||||
// This test verifies rule insertion order in nftables peer ACLs
|
// This test verifies rule insertion order in nftables peer ACLs
|
||||||
// We add accept rule first, then deny rule to test ordering behavior
|
// We add accept rule first, then deny rule to test ordering behavior
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
// just check on the local interface
|
// just check on the local interface
|
||||||
manager, err := Create(mock)
|
manager, err := Create(mock, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
stdout, stderr := runIptablesSave(t)
|
stdout, stderr := runIptablesSave(t)
|
||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "failed to create manager")
|
require.NoError(t, err, "failed to create manager")
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/google/nftables/xt"
|
"github.com/google/nftables/xt"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -32,12 +33,16 @@ const (
|
|||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
chainNameForward = "FORWARD"
|
chainNameForward = "FORWARD"
|
||||||
|
chainNameMangleForward = "netbird-mangle-forward"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
|
||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
|
|
||||||
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
|
ipTCPHeaderMinSize = 40
|
||||||
)
|
)
|
||||||
|
|
||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
@@ -63,9 +68,10 @@ type router struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
mtu uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||||
r := &router{
|
r := &router{
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
@@ -73,6 +79,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
|||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -220,11 +227,23 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameMangleForward,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookForward,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
return fmt.Errorf("add single nat rule: %v", err)
|
return fmt.Errorf("add single nat rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.addMSSClampingRules(); err != nil {
|
||||||
|
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.acceptForwardRules(); err != nil {
|
if err := r.acceptForwardRules(); err != nil {
|
||||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||||
}
|
}
|
||||||
@@ -745,6 +764,83 @@ func (r *router) addPostroutingRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||||
|
// TODO: Add IPv6 support
|
||||||
|
func (r *router) addMSSClampingRules() error {
|
||||||
|
mss := r.mtu - ipTCPHeaderMinSize
|
||||||
|
|
||||||
|
exprsOut := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyL4PROTO,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{unix.IPPROTO_TCP},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 13,
|
||||||
|
Len: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
DestRegister: 1,
|
||||||
|
SourceRegister: 1,
|
||||||
|
Len: 1,
|
||||||
|
Mask: []byte{0x02},
|
||||||
|
Xor: []byte{0x00},
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0x00},
|
||||||
|
},
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Exthdr{
|
||||||
|
DestRegister: 1,
|
||||||
|
Type: 2,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
Op: expr.ExthdrOpTcpopt,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpGt,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||||
|
},
|
||||||
|
&expr.Exthdr{
|
||||||
|
SourceRegister: 1,
|
||||||
|
Type: 2,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
Op: expr.ExthdrOpTcpopt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
r.conn.AddRule(&nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameMangleForward],
|
||||||
|
Exprs: exprsOut,
|
||||||
|
})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
for _, testCase := range test.InsertRuleTestCases {
|
for _, testCase := range test.InsertRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
// need fw manager to init both acl mgr and router for all chains to be present
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
for _, testCase := range test.RemoveRuleTestCases {
|
for _, testCase := range test.RemoveRuleTestCases {
|
||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
defer deleteWorkTable()
|
defer deleteWorkTable()
|
||||||
|
|
||||||
r, err := newRouter(workTable, ifaceMock)
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
require.NoError(t, err, "Failed to create router")
|
require.NoError(t, err, "Failed to create router")
|
||||||
require.NoError(t, r.init(workTable))
|
require.NoError(t, r.init(workTable))
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -10,6 +11,7 @@ type InterfaceState struct {
|
|||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress wgaddr.Address `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
|
MTU uint16 `json:"mtu"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Cleanup() error {
|
func (s *ShutdownState) Cleanup() error {
|
||||||
nft, err := Create(s.InterfaceState)
|
mtu := s.InterfaceState.MTU
|
||||||
|
if mtu == 0 {
|
||||||
|
mtu = iface.DefaultMTU
|
||||||
|
}
|
||||||
|
nft, err := Create(s.InterfaceState, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create nftables manager: %w", err)
|
return fmt.Errorf("create nftables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -27,7 +28,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
const layerTypeAll = 0
|
const (
|
||||||
|
layerTypeAll = 0
|
||||||
|
|
||||||
|
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||||
|
ipTCPHeaderMinSize = 40
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||||
@@ -36,6 +42,9 @@ const (
|
|||||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||||
|
|
||||||
|
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||||
|
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||||
|
|
||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
@@ -122,6 +131,10 @@ type Manager struct {
|
|||||||
|
|
||||||
netstackServices map[serviceKey]struct{}
|
netstackServices map[serviceKey]struct{}
|
||||||
netstackServiceMutex sync.RWMutex
|
netstackServiceMutex sync.RWMutex
|
||||||
|
|
||||||
|
mtu uint16
|
||||||
|
mssClampValue uint16
|
||||||
|
mssClampEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -140,16 +153,16 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
if nativeFirewall == nil {
|
if nativeFirewall == nil {
|
||||||
return nil, errors.New("native firewall is nil")
|
return nil, errors.New("native firewall is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
|||||||
return mgr, nil
|
return mgr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseCreateEnv() (bool, bool) {
|
func parseCreateEnv() (bool, bool, bool) {
|
||||||
var disableConntrack, enableLocalForwarding bool
|
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||||
var err error
|
var err error
|
||||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||||
disableConntrack, err = strconv.ParseBool(val)
|
disableConntrack, err = strconv.ParseBool(val)
|
||||||
@@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
|
||||||
|
disableMSSClamping, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
decoders: sync.Pool{
|
decoders: sync.Pool{
|
||||||
@@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
|
mtu: mtu,
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
|
if !disableMSSClamping {
|
||||||
|
m.mssClampEnabled = true
|
||||||
|
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||||
|
}
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
@@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
if err := m.initForwarder(); err != nil {
|
if err := m.initForwarder(); err != nil {
|
||||||
log.Errorf("failed to initialize forwarder: %v", err)
|
log.Errorf("failed to initialize forwarder: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, fmt.Errorf("set filter: %w", err)
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
@@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error {
|
|||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
@@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
switch d.decoded[1] {
|
||||||
return true
|
case layers.LayerTypeUDP:
|
||||||
|
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||||
|
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||||
|
if m.mssClampEnabled {
|
||||||
|
m.clampTCPMSS(packetData, d)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||||
@@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
|
||||||
|
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
|
||||||
|
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||||
|
if !d.tcp.SYN {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(d.tcp.Options) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
mssOptionIndex := -1
|
||||||
|
var currentMSS uint16
|
||||||
|
for i, opt := range d.tcp.Options {
|
||||||
|
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||||
|
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||||
|
if currentMSS > m.mssClampValue {
|
||||||
|
mssOptionIndex = i
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if mssOptionIndex == -1 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderSize < 20 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||||
|
tcpHeaderStart := ipHeaderSize
|
||||||
|
tcpOptionsStart := tcpHeaderStart + 20
|
||||||
|
|
||||||
|
optOffset := tcpOptionsStart
|
||||||
|
for j := 0; j < mssOptionIndex; j++ {
|
||||||
|
switch d.tcp.Options[j].OptionType {
|
||||||
|
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
|
||||||
|
optOffset++
|
||||||
|
default:
|
||||||
|
optOffset += 2 + len(d.tcp.Options[j].OptionData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mssValueOffset := optOffset + 2
|
||||||
|
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||||
|
|
||||||
|
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
|
||||||
|
tcpLayer := packetData[tcpHeaderStart:]
|
||||||
|
tcpLength := len(packetData) - tcpHeaderStart
|
||||||
|
|
||||||
|
tcpLayer[16] = 0
|
||||||
|
tcpLayer[17] = 0
|
||||||
|
|
||||||
|
var pseudoSum uint32
|
||||||
|
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||||
|
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||||
|
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||||
|
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||||
|
pseudoSum += uint32(d.ip4.Protocol)
|
||||||
|
pseudoSum += uint32(tcpLength)
|
||||||
|
|
||||||
|
var sum uint32 = pseudoSum
|
||||||
|
for i := 0; i < tcpLength-1; i += 2 {
|
||||||
|
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||||
|
}
|
||||||
|
if tcpLength%2 == 1 {
|
||||||
|
sum += uint32(tcpLayer[tcpLength-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for sum > 0xFFFF {
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
}
|
||||||
|
|
||||||
|
checksum := ^uint16(sum)
|
||||||
|
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||||
transport := d.decoded[1]
|
transport := d.decoded[1]
|
||||||
switch transport {
|
switch transport {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
|
||||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
|
||||||
b.Helper()
|
|
||||||
|
|
||||||
ipv4 := &layers.IPv4{
|
|
||||||
TTL: 64,
|
|
||||||
Version: 4,
|
|
||||||
SrcIP: srcIP,
|
|
||||||
DstIP: dstIP,
|
|
||||||
Protocol: layers.IPProtocolTCP,
|
|
||||||
}
|
|
||||||
|
|
||||||
tcp := &layers.TCP{
|
|
||||||
SrcPort: layers.TCPPort(srcPort),
|
|
||||||
DstPort: layers.TCPPort(dstPort),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set TCP flags
|
|
||||||
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
|
||||||
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
|
||||||
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
|
||||||
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
|
||||||
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
|
||||||
|
|
||||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
|
||||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkRouteACLs(b *testing.B) {
|
func BenchmarkRouteACLs(b *testing.B) {
|
||||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||||
|
|
||||||
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
|
||||||
|
// This shows the overhead difference between the common case (non-SYN packets, fast path)
|
||||||
|
// and the rare case (SYN packets that need clamping, expensive path).
|
||||||
|
func BenchmarkMSSClamping(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
description string
|
||||||
|
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||||
|
frequency string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "syn_needs_clamp",
|
||||||
|
description: "SYN packet needing MSS clamping",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||||
|
},
|
||||||
|
frequency: "~0.1% of traffic - EXPENSIVE",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "syn_no_clamp_needed",
|
||||||
|
description: "SYN packet with already-small MSS",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
|
||||||
|
},
|
||||||
|
frequency: "~0.05% of traffic",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp_ack",
|
||||||
|
description: "Non-SYN TCP packet (ACK, data transfer)",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
frequency: "~60-70% of traffic - FAST PATH",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp_psh_ack",
|
||||||
|
description: "TCP data packet (PSH+ACK)",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
frequency: "~10-20% of traffic - FAST PATH",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp",
|
||||||
|
description: "UDP packet",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
frequency: "~20-30% of traffic - FAST PATH",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
manager.mssClampEnabled = true
|
||||||
|
manager.mssClampValue = 1240
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
packet := sc.genPacket(b, srcIP, dstIP)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
|
||||||
|
// for the common case (non-SYN TCP packets).
|
||||||
|
func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
enabled bool
|
||||||
|
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "disabled_tcp_ack",
|
||||||
|
enabled: false,
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_tcp_ack",
|
||||||
|
enabled: true,
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "disabled_syn_needs_clamp",
|
||||||
|
enabled: false,
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enabled_syn_needs_clamp",
|
||||||
|
enabled: true,
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
manager.mssClampEnabled = sc.enabled
|
||||||
|
if sc.enabled {
|
||||||
|
manager.mssClampValue = 1240
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
packet := sc.genPacket(b, srcIP, dstIP)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
|
||||||
|
func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp_ack_fast_path",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "syn_needs_clamp",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp_fast_path",
|
||||||
|
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||||
|
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
manager.mssClampEnabled = true
|
||||||
|
manager.mssClampValue = 1240
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.64.0.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
packet := sc.genPacket(b, srcIP, dstIP)
|
||||||
|
|
||||||
|
b.ReportAllocs()
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
|
||||||
|
b.Helper()
|
||||||
|
|
||||||
|
ip := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
IHL: 5,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
Seq: 1000,
|
||||||
|
Window: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
FixLengths: true,
|
||||||
|
ComputeChecksums: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -17,6 +18,7 @@ import (
|
|||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false, flowLogger)
|
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
|||||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestMSSClamping(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||||
|
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||||
|
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := net.ParseIP("100.10.0.2")
|
||||||
|
dstIP := net.ParseIP("8.8.8.8")
|
||||||
|
|
||||||
|
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
|
||||||
|
highMSS := uint16(1460)
|
||||||
|
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||||
|
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
|
||||||
|
d := parsePacket(t, packet)
|
||||||
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
|
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||||
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
|
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||||
|
lowMSS := uint16(1200)
|
||||||
|
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
|
||||||
|
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
|
||||||
|
d := parsePacket(t, packet)
|
||||||
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
|
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
|
||||||
|
highMSS := uint16(1460)
|
||||||
|
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||||
|
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
|
||||||
|
d := parsePacket(t, packet)
|
||||||
|
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||||
|
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||||
|
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||||
|
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
|
||||||
|
|
||||||
|
manager.filterOutbound(packet, len(packet))
|
||||||
|
|
||||||
|
d := parsePacket(t, packet)
|
||||||
|
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
ipLayer := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpLayer := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
Window: 65535,
|
||||||
|
Options: []layers.TCPOption{
|
||||||
|
{
|
||||||
|
OptionType: layers.TCPOptionKindMSS,
|
||||||
|
OptionLength: 4,
|
||||||
|
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
ipLayer := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpLayer := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
ACK: true,
|
||||||
|
Window: 65535,
|
||||||
|
Options: []layers.TCPOption{
|
||||||
|
{
|
||||||
|
OptionType: layers.TCPOptionKindMSS,
|
||||||
|
OptionLength: 4,
|
||||||
|
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
ipLayer := &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocolTCP,
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpLayer := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
Window: 65535,
|
||||||
|
}
|
||||||
|
|
||||||
|
if flags&uint16(conntrack.TCPSyn) != 0 {
|
||||||
|
tcpLayer.SYN = true
|
||||||
|
}
|
||||||
|
if flags&uint16(conntrack.TCPAck) != 0 {
|
||||||
|
tcpLayer.ACK = true
|
||||||
|
}
|
||||||
|
if flags&uint16(conntrack.TCPFin) != 0 {
|
||||||
|
tcpLayer.FIN = true
|
||||||
|
}
|
||||||
|
if flags&uint16(conntrack.TCPRst) != 0 {
|
||||||
|
tcpLayer.RST = true
|
||||||
|
}
|
||||||
|
if flags&uint16(conntrack.TCPPush) != 0 {
|
||||||
|
tcpLayer.PSH = true
|
||||||
|
}
|
||||||
|
|
||||||
|
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||||
|
require.NoError(tb, err)
|
||||||
|
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|||||||
@@ -45,7 +45,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
HandleLocal: false,
|
HandleLocal: false,
|
||||||
})
|
})
|
||||||
|
|
||||||
mtu, err := iface.GetDevice().MTU()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get MTU: %w", err)
|
|
||||||
}
|
|
||||||
nicID := tcpip.NICID(1)
|
nicID := tcpip.NICID(1)
|
||||||
endpoint := &endpoint{
|
endpoint := &endpoint{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
|||||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
|||||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -16,7 +17,7 @@ import (
|
|||||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
|||||||
func TestDNATMappingManagement(t *testing.T) {
|
func TestDNATMappingManagement(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
|||||||
func TestInboundPortDNAT(t *testing.T) {
|
func TestInboundPortDNAT(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
|
|||||||
func TestInboundPortDNATNegative(t *testing.T) {
|
func TestInboundPortDNATNegative(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false, flowLogger)
|
}, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false, flowLogger)
|
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
if !statefulMode {
|
if !statefulMode {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err = fw.Close(nil)
|
err = fw.Close(nil)
|
||||||
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err = fw.Close(nil)
|
err = fw.Close(nil)
|
||||||
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err = fw.Close(nil)
|
err = fw.Close(nil)
|
||||||
|
|||||||
@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|||||||
@@ -506,7 +506,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
Reference in New Issue
Block a user