[client] Clamp MSS on outbound traffic (#4735)

This commit is contained in:
Viktor Liu
2025-11-04 17:18:51 +01:00
committed by GitHub
parent 679c58ce47
commit 45c25dca84
24 changed files with 804 additions and 134 deletions

View File

@@ -15,13 +15,13 @@ import (
) )
// NewFirewall creates a firewall manager instance // NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
} }
// use userspace packet filtering firewall // use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type // FWType is the type for the firewall type
type FWType int type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables // on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic // in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers // so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall // for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if !iface.IsUserspaceBind() { if !iface.IsUserspaceBind() {
return fm, err return fm, err
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil { if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
} }
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} }
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface) fm, err := createFW(iface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create firewall: %s", err) return nil, fmt.Errorf("create firewall: %s", err)
} }
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
return fm, nil return fm, nil
} }
func createFW(iface IFaceMapper) (firewall.Manager, error) { func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
switch check() { switch check() {
case IPTABLES: case IPTABLES:
log.Info("creating an iptables firewall manager") log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface) return nbiptables.Create(iface, mtu)
case NFTABLES: case NFTABLES:
log.Info("creating an nftables firewall manager") log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface) return nbnftables.Create(iface, mtu)
default: default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall") log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found") return nil, errors.New("no firewall manager found")
} }
} }
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error var errUsp error
if fm != nil { if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else { } else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
} }
if errUsp != nil { if errUsp != nil {

View File

@@ -36,7 +36,7 @@ type iFaceMapper interface {
} }
// Create iptables firewall manager // Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, fmt.Errorf("init iptables: %w", err) return nil, fmt.Errorf("init iptables: %w", err)
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient, ipv4Client: iptablesClient,
} }
m.router, err = newRouter(iptablesClient, wgIface) m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(), UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
} }
stateManager.RegisterState(state) stateManager.RegisterState(state)

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err) require.NoError(t, err)
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
} }
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second) time.Sleep(time.Second)

View File

@@ -30,17 +30,20 @@ const (
chainPOSTROUTING = "POSTROUTING" chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING" chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT" chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE" chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR" chainRTRDR = "NETBIRD-RT-RDR"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT" routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE" routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre" jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre" jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post" jumpNatPost = "jump-nat-post"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre" markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post" markManglePost = "mark-mangle-post"
matchSet = "--match-set" matchSet = "--match-set"
@@ -48,6 +51,9 @@ const (
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
fwdSuffix = "_fwd" fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
) )
type ruleInfo struct { type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
ipsetCounter *ipsetCounter ipsetCounter *ipsetCounter
wgIface iFaceMapper wgIface iFaceMapper
legacyManagement bool legacyManagement bool
mtu uint16
stateManager *statemanager.Manager stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
} }
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{ r := &router{
iptablesClient: iptablesClient, iptablesClient: iptablesClient,
rules: make(map[string][]string), rules: make(map[string][]string),
wgIface: wgIface, wgIface: wgIface,
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
} }
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} { } {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil { if err != nil {
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
{chainRTPRE, tableMangle}, {chainRTPRE, tableMangle},
{chainRTNAT, tableNat}, {chainRTNAT, tableNat},
{chainRTRDR, tableNat}, {chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} { } {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("add jump rules: %w", err) return fmt.Errorf("add jump rules: %w", err)
} }
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil return nil
} }
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *router) insertEstablishedRule(chain string) error { func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished() establishedRule := getConntrackEstablished()
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
} }
func (r *router) cleanJumpRules() error { func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists { if rule, exists := r.rules[ruleKey]; exists {
var table, chain string var table, chain string
switch ruleKey { switch ruleKey {
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
case jumpNatPre: case jumpNatPre:
table = tableNat table = tableNat
chain = chainPREROUTING chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default: default:
return fmt.Errorf("unknown jump rule: %s", ruleKey) return fmt.Errorf("unknown jump rule: %s", ruleKey)
} }

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net" nbnet "github.com/netbirdio/netbird/client/net"
) )
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
assert.NoError(t, manager.Reset(), "shouldn't return error") assert.NoError(t, manager.Reset(), "shouldn't return error")
}() }()
// Now 5 rules:
// 1. established rule forward in // 1. established rule forward in
// 2. estbalished rule forward out // 2. estbalished rule forward out
// 3. jump rule to POST nat chain // 3. jump rule to POST nat chain
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 7. static return masquerade rule // 7. static return masquerade rule
// 8. mangle prerouting mark rule // 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule // 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map") // 10. jump rule to MSS clamping chain
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client") require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock) manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error") require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil)) require.NoError(t, manager.init(nil))
defer func() { defer func() {
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client") require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock) r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager") require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil)) require.NoError(t, r.init(nil))

View File

@@ -4,6 +4,7 @@ import (
"fmt" "fmt"
"sync" "sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -11,6 +12,7 @@ type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState) mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
ipt, err := Create(s.InterfaceState, mtu)
if err != nil { if err != nil {
return fmt.Errorf("create iptables manager: %w", err) return fmt.Errorf("create iptables manager: %w", err)
} }

View File

@@ -44,7 +44,7 @@ type Manager struct {
} }
// Create nftables firewall manager // Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) { func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
m := &Manager{ m := &Manager{
rConn: &nftables.Conn{}, rConn: &nftables.Conn{},
wgIface: wgIface, wgIface: wgIface,
@@ -53,7 +53,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
var err error var err error
m.router, err = newRouter(workTable, wgIface) m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil { if err != nil {
return nil, fmt.Errorf("create router: %w", err) return nil, fmt.Errorf("create router: %w", err)
} }
@@ -93,6 +93,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(), NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(), WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(), UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
}, },
}); err != nil { }); err != nil {
log.Errorf("failed to update state: %v", err) log.Errorf("failed to update state: %v", err)

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) { func TestNftablesManager(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
func TestNftablesManagerRuleOrder(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs // This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior // We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface // just check on the local interface
manager, err := Create(mock) manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr := runIptablesSave(t) stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr) verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager") require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil)) require.NoError(t, manager.Init(nil))

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/nftables/xt" "github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -32,12 +33,16 @@ const (
chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect" chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD" chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptForwardRuleOif = "frwacceptoif"
dnatSuffix = "_dnat" dnatSuffix = "_dnat"
snatSuffix = "_snat" snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
) )
const refreshRulesMapError = "refresh rules map: %w" const refreshRulesMapError = "refresh rules map: %w"
@@ -63,9 +68,10 @@ type router struct {
wgIface iFaceMapper wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool legacyManagement bool
mtu uint16
} }
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{ r := &router{
conn: &nftables.Conn{}, conn: &nftables.Conn{},
workTable: workTable, workTable: workTable,
@@ -73,6 +79,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
rules: make(map[string]*nftables.Rule), rules: make(map[string]*nftables.Rule),
wgIface: wgIface, wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(), ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
} }
r.ipsetCounter = refcounter.New( r.ipsetCounter = refcounter.New(
@@ -220,11 +227,23 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter, Type: nftables.ChainTypeFilter,
}) })
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark // Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil { if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err) return fmt.Errorf("add single nat rule: %v", err)
} }
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil { if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err) log.Errorf("failed to add accept rules for the forward chain: %s", err)
} }
@@ -745,6 +764,83 @@ func (r *router) addPostroutingRules() error {
return nil return nil
} }
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true) sourceExp, err := r.applyNetwork(pair.Source, nil, true)

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test" "github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
) )
const ( const (
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases { for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present // need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases { for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock) manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
}) })
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable)) require.NoError(t, r.init(workTable))
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable() defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock) r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router") require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable)) require.NoError(t, r.init(workTable))

View File

@@ -3,6 +3,7 @@ package nftables
import ( import (
"fmt" "fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -10,6 +11,7 @@ type InterfaceState struct {
NameStr string `json:"name"` NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"` WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"` UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
} }
func (i *InterfaceState) Name() string { func (i *InterfaceState) Name() string {
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
} }
func (s *ShutdownState) Cleanup() error { func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState) mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
nft, err := Create(s.InterfaceState, mtu)
if err != nil { if err != nil {
return fmt.Errorf("create nftables manager: %w", err) return fmt.Errorf("create nftables manager: %w", err)
} }

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"net" "net"
@@ -27,7 +28,12 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
) )
const layerTypeAll = 0 const (
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const ( const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@@ -36,6 +42,9 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available. // EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
@@ -122,6 +131,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{} netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
} }
// decoder for packages // decoder for packages
@@ -140,16 +153,16 @@ type decoder struct {
} }
// Create userspace firewall manager constructor // Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger) return create(iface, nil, disableServerRoutes, flowLogger, mtu)
} }
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil { if nativeFirewall == nil {
return nil, errors.New("native firewall is nil") return nil, errors.New("native firewall is nil")
} }
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil return mgr, nil
} }
func parseCreateEnv() (bool, bool) { func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding bool var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error var err error
if val := os.Getenv(EnvDisableConntrack); val != "" { if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val) disableConntrack, err = strconv.ParseBool(val)
@@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
} }
} }
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
disableMSSClamping, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
}
}
return disableConntrack, enableLocalForwarding return disableConntrack, enableLocalForwarding, disableMSSClamping
} }
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv() disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{ m := &Manager{
decoders: sync.Pool{ decoders: sync.Pool{
@@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
dnatMappings: make(map[netip.Addr]netip.Addr), dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{}, portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}), netstackServices: make(map[serviceKey]struct{}),
mtu: mtu,
} }
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err) return nil, fmt.Errorf("update local IPs: %w", err)
} }
if disableConntrack { if disableConntrack {
log.Info("conntrack is disabled") log.Info("conntrack is disabled")
} else { } else {
@@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
} }
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding { if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil { if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err) log.Errorf("failed to initialize forwarder: %v", err)
} }
} }
if err := iface.SetFilter(m); err != nil { if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err) return nil, fmt.Errorf("set filter: %w", err)
} }
@@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported") return errors.New("forwarding not supported")
} }
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
if err != nil { if err != nil {
m.routingEnabled.Store(false) m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err) return fmt.Errorf("create forwarder: %w", err)
@@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return false return false
} }
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { switch d.decoded[1] {
return true case layers.LayerTypeUDP:
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
}
case layers.LayerTypeTCP:
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
m.clampTCPMSS(packetData, d)
}
} }
m.trackOutbound(d, srcIP, dstIP, packetData, size) m.trackOutbound(d, srcIP, dstIP, packetData, size)
@@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags return flags
} }
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
if !d.tcp.SYN {
return false
}
if len(d.tcp.Options) == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
}
}
if mssOptionIndex == -1 {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
optOffset := tcpOptionsStart
for j := 0; j < mssOptionIndex; j++ {
switch d.tcp.Options[j].OptionType {
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
optOffset++
default:
optOffset += 2 + len(d.tcp.Options[j].OptionData)
}
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
}
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
tcpLayer[16] = 0
tcpLayer[17] = 0
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
if tcpLength%2 == 1 {
sum += uint32(tcpLayer[tcpLength-1]) << 8
}
for sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16)
}
checksum := ^uint16(sum)
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1] transport := d.decoded[1]
switch transport { switch transport {

View File

@@ -17,6 +17,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup // Create manager and basic setup
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() { b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{ manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() { defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
}) })
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
} }
} }
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) { func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16") manager := setupRoutedManager(b, "10.10.0.100/16")
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
} }
} }
} }
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
// This shows the overhead difference between the common case (non-SYN packets, fast path)
// and the rare case (SYN packets that need clamping, expensive path).
func BenchmarkMSSClamping(b *testing.B) {
scenarios := []struct {
name string
description string
genPacket func(*testing.B, net.IP, net.IP) []byte
frequency string
}{
{
name: "syn_needs_clamp",
description: "SYN packet needing MSS clamping",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
frequency: "~0.1% of traffic - EXPENSIVE",
},
{
name: "syn_no_clamp_needed",
description: "SYN packet with already-small MSS",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
},
frequency: "~0.05% of traffic",
},
{
name: "tcp_ack",
description: "Non-SYN TCP packet (ACK, data transfer)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
frequency: "~60-70% of traffic - FAST PATH",
},
{
name: "tcp_psh_ack",
description: "TCP data packet (PSH+ACK)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
frequency: "~10-20% of traffic - FAST PATH",
},
{
name: "udp",
description: "UDP packet",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
frequency: "~20-30% of traffic - FAST PATH",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
// for the common case (non-SYN TCP packets).
func BenchmarkMSSClampingOverhead(b *testing.B) {
scenarios := []struct {
name string
enabled bool
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "disabled_tcp_ack",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "enabled_tcp_ack",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "disabled_syn_needs_clamp",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "enabled_syn_needs_clamp",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
func BenchmarkMSSClampingMemory(b *testing.B) {
scenarios := []struct {
name string
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "tcp_ack_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "syn_needs_clamp",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "udp_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
b.Helper()
ip := &layers.IPv4{
Version: 4,
IHL: 5,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Seq: 1000,
Window: 65535,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
return buf.Bytes()
}

View File

@@ -12,6 +12,7 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device" wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, manager) require.NotNil(t, manager)
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err) require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting()) require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager) require.NotNil(tb, manager)
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
}, },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))

View File

@@ -1,6 +1,7 @@
package uspfilter package uspfilter
import ( import (
"encoding/binary"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -17,6 +18,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Errorf("failed to create Manager: %v", err) t.Errorf("failed to create Manager: %v", err)
return return
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
} }
// creating manager instance // creating manager instance
manager, err := Create(iface, false, flowLogger) manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil { if err != nil {
t.Fatalf("Failed to create Manager: %s", err) t.Fatalf("Failed to create Manager: %s", err)
} }
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.udpTracker.Close() manager.udpTracker.Close()
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{ ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
time.Sleep(time.Second) time.Sleep(time.Second)
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker manager.udpTracker.Close() // Close the existing tracker
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
} }
manager, err := Create(ifaceMock, false, flowLogger) manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
t.Cleanup(func() { t.Cleanup(func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc) require.Equal(t, tc.expected, isAllowed, tc.desc)
} }
} }
func TestMSSClamping(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
srcIP := net.ParseIP("100.10.0.2")
dstIP := net.ParseIP("8.8.8.8")
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
lowMSS := uint16(1200)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
})
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
})
}
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
ACK: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Window: 65535,
}
if flags&uint16(conntrack.TCPSyn) != 0 {
tcpLayer.SYN = true
}
if flags&uint16(conntrack.TCPAck) != 0 {
tcpLayer.ACK = true
}
if flags&uint16(conntrack.TCPFin) != 0 {
tcpLayer.FIN = true
}
if flags&uint16(conntrack.TCPRst) != 0 {
tcpLayer.RST = true
}
if flags&uint16(conntrack.TCPPush) != 0 {
tcpLayer.PSH = true
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}

View File

@@ -45,7 +45,7 @@ type Forwarder struct {
netstack bool netstack bool
} }
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{ s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{ TransportProtocols: []stack.TransportProtocolFactory{
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
HandleLocal: false, HandleLocal: false,
}) })
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1) nicID := tcpip.NICID(1)
endpoint := &endpoint{ endpoint := &endpoint{
logger: logger, logger: logger,
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
} }
if err := s.CreateNIC(nicID, endpoint); err != nil { if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err) return nil, fmt.Errorf("create NIC: %v", err)
} }
protoAddr := tcpip.ProtocolAddress{ protoAddr := tcpip.ProtocolAddress{

View File

@@ -49,7 +49,7 @@ type idleConn struct {
conn *udpPacketConn conn *udpPacketConn
} }
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{ f := &udpForwarder{
logger: logger, logger: logger,

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
b.Run(sc.name, func(b *testing.B) { b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err) require.NoError(b, err)
defer func() { defer func() {
require.NoError(b, manager.Close(nil)) require.NoError(b, manager.Close(nil))

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
) )
@@ -16,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) { func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) { func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) { func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
func TestInboundPortDNATNegative(t *testing.T) { func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{ manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil }, SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger) }, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
require.NoError(t, manager.Close(nil)) require.NoError(t, manager.Close(nil))

View File

@@ -10,6 +10,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager" fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
) )
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
}, },
} }
m, err := Create(ifaceMock, false, flowLogger) m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
if !statefulMode { if !statefulMode {

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow" "github.com/netbirdio/netbird/client/internal/netflow"
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes() }).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err = fw.Close(nil) err = fw.Close(nil)

View File

@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err return nil, err
} }
pf, err := uspfilter.Create(wgIface, false, flowLogger) pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil { if err != nil {
t.Fatalf("failed to create uspfilter: %v", err) t.Fatalf("failed to create uspfilter: %v", err)
return nil, err return nil, err

View File

@@ -506,7 +506,7 @@ func (e *Engine) createFirewall() error {
} }
var err error var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil { if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err) log.Errorf("failed creating firewall manager: %s", err)
return nil return nil