Compare commits

...

8 Commits

Author SHA1 Message Date
Krzysztof Nazarewski
ec4d09215c client/ui: more understandable connection lost messages 2025-11-05 18:08:02 +01:00
Viktor Liu
c92e6c1b5f [client] Block on all subsystems on shutdown (#4709) 2025-11-05 12:15:37 +01:00
Viktor Liu
641eb5140b [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) 2025-11-04 21:56:53 +01:00
Viktor Liu
45c25dca84 [client] Clamp MSS on outbound traffic (#4735) 2025-11-04 17:18:51 +01:00
Viktor Liu
679c58ce47 [client] Set up networkd to ignore ip rules (#4730) 2025-11-04 17:06:35 +01:00
Pascal Fischer
719283c792 [management] update db connection lifecycle configuration (#4740) 2025-11-03 17:40:12 +01:00
dependabot[bot]
a2313a5ba4 [client] Bump github.com/quic-go/quic-go from 0.48.2 to 0.49.1 (#4621)
Bumps [github.com/quic-go/quic-go](https://github.com/quic-go/quic-go) from 0.48.2 to 0.49.1.
- [Release notes](https://github.com/quic-go/quic-go/releases)
- [Commits](https://github.com/quic-go/quic-go/compare/v0.48.2...v0.49.1)

---
updated-dependencies:
- dependency-name: github.com/quic-go/quic-go
  dependency-version: 0.49.1
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-01 15:27:22 +01:00
Zoltan Papp
8c108ccad3 [client] Extend Darwin network monitoring with wakeup detection 2025-10-31 19:19:14 +01:00
43 changed files with 1441 additions and 502 deletions

View File

@@ -10,6 +10,8 @@ import (
"path/filepath"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/kardianos/service"
"github.com/spf13/cobra"
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
svcConfig.Option["LogDirectory"] = dir
}
}
if err := configureSystemdNetworkd(); err != nil {
log.Warnf("failed to configure systemd-networkd: %v", err)
}
}
if runtime.GOOS == "windows" {
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err)
}
if runtime.GOOS == "linux" {
if err := cleanupSystemdNetworkd(); err != nil {
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
}
}
cmd.Println("NetBird service has been uninstalled")
return nil
},
@@ -245,3 +257,45 @@ func isServiceRunning() (bool, error) {
return status == service.StatusRunning, nil
}
const (
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
# routes and policy rules managed by NetBird.
[Network]
ManageForeignRoutes=no
ManageForeignRoutingPolicyRules=no
`
)
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
parentDir := filepath.Dir(networkdConfDir)
if _, err := os.Stat(parentDir); os.IsNotExist(err) {
log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)
}
return nil
}
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
func cleanupSystemdNetworkd() error {
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
return nil
}
if err := os.Remove(networkdConfFile); err != nil {
return fmt.Errorf("remove networkd configuration: %w", err)
}
return nil
}

View File

@@ -15,13 +15,13 @@ import (
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}

View File

@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if !iface.IsUserspaceBind() {
return fm, err
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
return nbiptables.Create(iface, mtu)
case NFTABLES:
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
return nbnftables.Create(iface, mtu)
default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {

View File

@@ -36,7 +36,7 @@ type iFaceMapper interface {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("init iptables: %w", err)
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface)
m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err)
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)

View File

@@ -30,17 +30,20 @@ const (
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
@@ -48,6 +51,9 @@ const (
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
mtu uint16
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists {
var table, chain string
switch ruleKey {
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
// Now 5 rules:
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 7. static return masquerade rule
// 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
// 10. jump rule to MSS clamping chain
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -11,6 +12,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
ipt, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}

View File

@@ -100,6 +100,9 @@ type Manager interface {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
ip net.IP,

View File

@@ -29,8 +29,6 @@ const (
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
const flushError = "flush: %w"
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
},
)
return m.rConn.AddRule(&nftables.Rule{
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
return nil
}
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}

View File

@@ -1,11 +1,11 @@
package nftables
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
"github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
)
const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
// envTableName is the environment variable to override the table name
envTableName = "NB_NFTABLES_TABLE"
tableNameFilter = "filter"
chainNameInput = "INPUT"
)
func getTableName() string {
if name := os.Getenv(envTableName); name != "" {
return name
}
return tableNameNetbird
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
@@ -44,16 +53,16 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
return nil
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
@@ -398,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
return nil, fmt.Errorf("list of tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -32,12 +33,17 @@ const (
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -63,9 +69,10 @@ type router struct {
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
return nil
}
// Reset cleans existing nftables default forward rules from the system
// Reset cleans existing nftables filter table rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
return r.acceptFilterRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
return r.acceptFilterRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
}
}
func (r *router) acceptForwardRulesNftables() error {
func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
func (r *router) removeAcceptFilterRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
return r.removeAcceptFilterRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
)
const (
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))

View File

@@ -3,6 +3,7 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -10,6 +11,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
nft, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net"
@@ -27,7 +28,12 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const layerTypeAll = 0
const (
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@@ -36,6 +42,9 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
@@ -122,6 +131,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
}
// decoder for packages
@@ -140,16 +153,16 @@ type decoder struct {
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger)
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
@@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}
func parseCreateEnv() (bool, bool) {
var disableConntrack, enableLocalForwarding bool
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
@@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
}
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
disableMSSClamping, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
}
}
return disableConntrack, enableLocalForwarding
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
decoders: sync.Pool{
@@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
mtu: mtu,
}
m.routingEnabled.Store(false)
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
}
if disableConntrack {
log.Info("conntrack is disabled")
} else {
@@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
@@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
if err != nil {
m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err)
@@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return false
}
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
switch d.decoded[1] {
case layers.LayerTypeUDP:
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
}
case layers.LayerTypeTCP:
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
m.clampTCPMSS(packetData, d)
}
}
m.trackOutbound(d, srcIP, dstIP, packetData, size)
@@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
if !d.tcp.SYN {
return false
}
if len(d.tcp.Options) == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
}
}
if mssOptionIndex == -1 {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
optOffset := tcpOptionsStart
for j := 0; j < mssOptionIndex; j++ {
switch d.tcp.Options[j].OptionType {
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
optOffset++
default:
optOffset += 2 + len(d.tcp.Options[j].OptionData)
}
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
}
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
tcpLayer[16] = 0
tcpLayer[17] = 0
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
if tcpLength%2 == 1 {
sum += uint32(tcpLayer[tcpLength-1]) << 8
}
for sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16)
}
checksum := ^uint16(sum)
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1]
switch transport {

View File

@@ -17,6 +17,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
}
}
}
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
// This shows the overhead difference between the common case (non-SYN packets, fast path)
// and the rare case (SYN packets that need clamping, expensive path).
func BenchmarkMSSClamping(b *testing.B) {
scenarios := []struct {
name string
description string
genPacket func(*testing.B, net.IP, net.IP) []byte
frequency string
}{
{
name: "syn_needs_clamp",
description: "SYN packet needing MSS clamping",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
frequency: "~0.1% of traffic - EXPENSIVE",
},
{
name: "syn_no_clamp_needed",
description: "SYN packet with already-small MSS",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
},
frequency: "~0.05% of traffic",
},
{
name: "tcp_ack",
description: "Non-SYN TCP packet (ACK, data transfer)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
frequency: "~60-70% of traffic - FAST PATH",
},
{
name: "tcp_psh_ack",
description: "TCP data packet (PSH+ACK)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
frequency: "~10-20% of traffic - FAST PATH",
},
{
name: "udp",
description: "UDP packet",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
frequency: "~20-30% of traffic - FAST PATH",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
// for the common case (non-SYN TCP packets).
func BenchmarkMSSClampingOverhead(b *testing.B) {
scenarios := []struct {
name string
enabled bool
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "disabled_tcp_ack",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "enabled_tcp_ack",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "disabled_syn_needs_clamp",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "enabled_syn_needs_clamp",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
func BenchmarkMSSClampingMemory(b *testing.B) {
scenarios := []struct {
name string
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "tcp_ack_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "syn_needs_clamp",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "udp_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
b.Helper()
ip := &layers.IPv4{
Version: 4,
IHL: 5,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Seq: 1000,
Window: 65535,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
return buf.Bytes()
}

View File

@@ -12,6 +12,7 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -17,6 +18,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false, flowLogger)
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}
func TestMSSClamping(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
srcIP := net.ParseIP("100.10.0.2")
dstIP := net.ParseIP("8.8.8.8")
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
lowMSS := uint16(1200)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
})
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
})
}
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
ACK: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Window: 65535,
}
if flags&uint16(conntrack.TCPSyn) != 0 {
tcpLayer.SYN = true
}
if flags&uint16(conntrack.TCPAck) != 0 {
tcpLayer.ACK = true
}
if flags&uint16(conntrack.TCPFin) != 0 {
tcpLayer.FIN = true
}
if flags&uint16(conntrack.TCPRst) != 0 {
tcpLayer.RST = true
}
if flags&uint16(conntrack.TCPPush) != 0 {
tcpLayer.PSH = true
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}

View File

@@ -45,7 +45,7 @@ type Forwarder struct {
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
return nil, fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{

View File

@@ -49,7 +49,7 @@ type idleConn struct {
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -16,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -10,6 +10,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
<-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
}

View File

@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.ctxCancel()
s.shutdownWg.Wait()
s.mux.Lock()
defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.applyHostConfig()
s.shutdownWg.Add(1)
go func() {
// persist dns state right away
defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

View File

@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, false, flowLogger)
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@@ -15,6 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
@@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error {
}
}
m.unregisterNetstackServices()
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
@@ -158,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
return fmt.Errorf("add udp firewall rule: %w", err)
}
m.fwRules = dnsRules
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
return fmt.Errorf("add tcp firewall rule: %w", err)
}
if err := m.firewall.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
m.fwRules = dnsRules
m.tcpRules = tcpRules
m.registerNetstackServices()
return nil
}
func (m *Manager) registerNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) unregisterNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {

View File

@@ -148,6 +148,8 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex
config *EngineConfig
mobileDep MobileDependency
@@ -200,8 +202,10 @@ type Engine struct {
flowManager nftypes.FlowManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
}
@@ -298,17 +302,12 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
@@ -325,10 +324,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
// stop flow manager after wg interface is gone
@@ -336,8 +331,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -348,12 +341,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -483,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
@@ -506,7 +539,7 @@ func (e *Engine) createFirewall() error {
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
@@ -674,9 +707,11 @@ func (e *Engine) removeAllPeers() error {
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
e.sshMux.Lock()
if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.sshMux.Unlock()
e.connMgr.RemovePeerConn(peerKey)
@@ -878,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
e.sshMux.Lock()
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
@@ -885,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err)
}
e.sshServer = server
e.sshMux.Unlock()
go func() {
// blocking
err = e.sshServer.Start()
err = server.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshMux.Lock()
e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server")
}()
} else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
} else {
e.sshMux.Lock()
if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
e.sshServer = nil
e.sshMux.Unlock()
}
return nil
}
@@ -949,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
@@ -1125,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
@@ -1135,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
e.sshMux.Unlock()
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1377,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
@@ -1494,12 +1544,14 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil)
}
e.sshMux.Lock()
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
e.sshMux.Unlock()
if e.firewall != nil {
err := e.firewall.Close(e.stateManager)
@@ -1730,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
return allHealthy
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1753,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() {
}
e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
@@ -1763,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() {
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}
@@ -1873,7 +1929,6 @@ func (e *Engine) updateDNSForwarder(
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
e.registerDNSServices()
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
@@ -1893,34 +1948,9 @@ func (e *Engine) stopDNSForwarder() {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.unregisterDNSServices()
e.dnsForwardMgr = nil
}
func (e *Engine) registerDNSServices() {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
}
}
}
func (e *Engine) unregisterDNSServices() {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
}
}
}
func (e *Engine) GetNet() (*netstack.Net, error) {
e.syncMsgMux.Lock()
intf := e.wgInterface

View File

@@ -24,6 +24,7 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
return nil
}
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
// Close cleans up all resources
func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
m.mux.Unlock()
m.shutdownWg.Wait()
}
// GetLogger returns the flow logger

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
//go:build dragonfly || freebsd || netbsd || openbsd
package networkmonitor
@@ -6,21 +6,19 @@ import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
fd, err := prepareFd()
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
err := unix.Close(fd)
if err != nil && !errors.Is(err, unix.EBADF) {
@@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
}
}()
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
return routeCheck(ctx, fd, nexthopv4, nexthopv6)
}

View File

@@ -0,0 +1,92 @@
//go:build dragonfly || freebsd || netbsd || openbsd || darwin
package networkmonitor
import (
"context"
"errors"
"fmt"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
func prepareFd() (int, error) {
return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
}
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
for {
select {
case <-ctx.Done():
return ctx.Err()
default:
buf := make([]byte, 2048)
n, err := unix.Read(fd, buf)
if err != nil {
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
}
continue
}
if n < unix.SizeofRtMsghdr {
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
continue
}
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
switch msg.Type {
// handle route changes
case unix.RTM_ADD, syscall.RTM_DELETE:
route, err := parseRouteMessage(buf[:n])
if err != nil {
log.Debugf("Network monitor: error parsing routing message: %v", err)
continue
}
if route.Dst.Bits() != 0 {
continue
}
intf := "<nil>"
if route.Interface != nil {
intf = route.Interface.Name
}
switch msg.Type {
case unix.RTM_ADD:
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
return nil
case unix.RTM_DELETE:
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
return nil
}
}
}
}
}
}
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
if err != nil {
return nil, fmt.Errorf("parse RIB: %v", err)
}
if len(msgs) != 1 {
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
}
msg, ok := msgs[0].(*route.RouteMessage)
if !ok {
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
}
return systemops.MsgToRoute(msg)
}

View File

@@ -0,0 +1,149 @@
//go:build darwin && !ios
package networkmonitor
import (
"context"
"errors"
"fmt"
"hash/fnv"
"os/exec"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
)
// todo: refactor to not use static functions
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
fd, err := prepareFd()
if err != nil {
return fmt.Errorf("open routing socket: %v", err)
}
defer func() {
if err := unix.Close(fd); err != nil {
if !errors.Is(err, unix.EBADF) {
log.Warnf("Network monitor: failed to close routing socket: %v", err)
}
}
}()
routeChanged := make(chan struct{})
go func() {
_ = routeCheck(ctx, fd, nexthopv4, nexthopv6)
close(routeChanged)
}()
wakeUp := make(chan struct{})
go func() {
wakeUpListen(ctx)
close(wakeUp)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-routeChanged:
if ctx.Err() != nil {
return ctx.Err()
}
log.Infof("route change detected")
return nil
case <-wakeUp:
if ctx.Err() != nil {
return ctx.Err()
}
log.Infof("wakeup detected")
return nil
}
}
func wakeUpListen(ctx context.Context) {
log.Infof("start to watch for system wakeups")
var (
initialHash uint32
err error
)
// Keep retrying until initial sysctl succeeds or context is canceled
for {
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
default:
initialHash, err = readSleepTimeHash()
if err != nil {
log.Errorf("failed to detect initial sleep time: %v", err)
select {
case <-ctx.Done():
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
return
case <-time.After(3 * time.Second):
continue
}
}
log.Debugf("initial wakeup hash: %d", initialHash)
break
}
break
}
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Info("context canceled, stopping wakeUpListen")
return
case <-ticker.C:
newHash, err := readSleepTimeHash()
if err != nil {
log.Errorf("failed to read sleep time hash: %v", err)
continue
}
if newHash == initialHash {
log.Tracef("no wakeup detected")
continue
}
upOut, err := exec.Command("uptime").Output()
if err != nil {
log.Errorf("failed to run uptime command: %v", err)
upOut = []byte("unknown")
}
log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut)
return
}
}
}
func readSleepTimeHash() (uint32, error) {
cmd := exec.Command("sysctl", "kern.sleeptime")
out, err := cmd.Output()
if err != nil {
return 0, fmt.Errorf("failed to run sysctl: %w", err)
}
h, err := hash(out)
if err != nil {
return 0, fmt.Errorf("failed to compute hash: %w", err)
}
return h, nil
}
func hash(data []byte) (uint32, error) {
hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher
if _, err := hasher.Write(data); err != nil {
return 0, err
}
return hasher.Sum32(), nil
}

View File

@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
event := make(chan struct{}, 1)
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
log.Infof("start watching for network changes")
// debounce changes
timer := time.NewTimer(0)
timer.Stop()

View File

@@ -19,11 +19,10 @@ type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}

View File

@@ -81,6 +81,7 @@ type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
shutdownWg sync.WaitGroup
clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector
serverRouter *server.Router
@@ -283,6 +284,7 @@ func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
m.shutdownWg.Wait()
if m.serverRouter != nil {
m.serverRouter.CleanUp()
}
@@ -485,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
}
@@ -527,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
}
update := client.RoutesUpdate{
UpdateSerial: updateSerial,

View File

@@ -9,8 +9,6 @@ import (
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route"
)
@@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
defer rs.mu.RUnlock()
if rs.deselectAll {
log.Debugf("Route %s not selected (deselect all)", routeID)
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
return isSelected
}

View File

@@ -280,6 +280,7 @@ type serviceClient struct {
blockLANAccess bool
connected bool
connectivityLost bool
update *version.Update
daemonVersion string
updateIndicationLock sync.Mutex
@@ -703,21 +704,41 @@ func (s *serviceClient) menuDownClick() error {
return nil
}
func (s *serviceClient) onStatusFailed(err error) {
log.Errorf("get client daemon service status: %v", err)
if s.connected {
s.connectivityLost = true
s.app.SendNotification(fyne.NewNotification(
"Warning",
"NetBird client service controls were interrupted.\n"+
"They should restore automatically.",
))
}
s.setDisconnectedStatus()
}
func (s *serviceClient) onStatusSucceeded() {
if s.connectivityLost {
s.connectivityLost = false
s.app.SendNotification(fyne.NewNotification(
"Info",
"NetBird client service controls were restored.",
))
}
}
func (s *serviceClient) updateStatus() error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
s.onStatusFailed(err)
return err
}
err = backoff.Retry(func() error {
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
if s.connected {
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
}
s.setDisconnectedStatus()
s.onStatusFailed(err)
return err
}
s.onStatusSucceeded()
s.updateIndicationLock.Lock()
defer s.updateIndicationLock.Unlock()

4
go.mod
View File

@@ -76,7 +76,7 @@ require (
github.com/pion/transport/v3 v3.0.7
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.22.0
github.com/quic-go/quic-go v0.48.2
github.com/quic-go/quic-go v0.49.1
github.com/redis/go-redis/v9 v9.7.3
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
@@ -241,7 +241,7 @@ require (
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
go.opentelemetry.io/otel/trace v1.35.0 // indirect
go.uber.org/mock v0.4.0 // indirect
go.uber.org/mock v0.5.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect

8
go.sum
View File

@@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
@@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=

View File

@@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
@@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e
return nil, err
}
if storeEngine == types.SqliteStoreEngine {
sqlDB.SetMaxOpenConns(1)
} else {
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
if err != nil {
conns = runtime.NumCPU()
}
sqlDB.SetMaxOpenConns(conns)
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
if err != nil {
conns = runtime.NumCPU()
}
if storeEngine == types.SqliteStoreEngine {
conns = 1
}
sqlDB.SetMaxOpenConns(conns)
sqlDB.SetMaxIdleConns(conns)
sqlDB.SetConnMaxLifetime(time.Hour)
sqlDB.SetConnMaxIdleTime(3 * time.Minute)
log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
conns, conns, time.Hour, 3*time.Minute)
return db, nil
}

View File

@@ -89,8 +89,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
}
sql.SetMaxOpenConns(conns)
sql.SetMaxIdleConns(conns)
sql.SetConnMaxLifetime(time.Hour)
sql.SetConnMaxIdleTime(3 * time.Minute)
log.WithContext(ctx).Infof("Set max open db connections to %d", conns)
log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
conns, conns, time.Hour, 3*time.Minute)
if skipMigration {
log.WithContext(ctx).Infof("skipping migration")