diff --git a/client/Dockerfile b/client/Dockerfile index b2f627409..5cd459357 100644 --- a/client/Dockerfile +++ b/client/Dockerfile @@ -4,7 +4,7 @@ # sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client . # sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest -FROM alpine:3.22.0 +FROM alpine:3.22.2 # iproute2: busybox doesn't display ip rules properly RUN apk add --no-cache \ bash \ diff --git a/client/cmd/debug.go b/client/cmd/debug.go index 18f3547ca..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,12 +303,18 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + statusOutputString = nbstatus.ParseToFullDetailSummary( - nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", ""), + nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil, "", profName), ) } return statusOutputString diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 075ead44e..2a87e538d 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -10,6 +10,8 @@ import ( 
"path/filepath" "runtime" + log "github.com/sirupsen/logrus" + "github.com/kardianos/service" "github.com/spf13/cobra" @@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { svcConfig.Option["LogDirectory"] = dir } } + + if err := configureSystemdNetworkd(); err != nil { + log.Warnf("failed to configure systemd-networkd: %v", err) + } } if runtime.GOOS == "windows" { @@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{ return fmt.Errorf("uninstall service: %w", err) } + if runtime.GOOS == "linux" { + if err := cleanupSystemdNetworkd(); err != nil { + log.Warnf("failed to cleanup systemd-networkd configuration: %v", err) + } + } + cmd.Println("NetBird service has been uninstalled") return nil }, @@ -245,3 +257,45 @@ func isServiceRunning() (bool, error) { return status == service.StatusRunning, nil } + +const ( + networkdConfDir = "/etc/systemd/networkd.conf.d" + networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" + networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing +# routes and policy rules managed by NetBird. + +[Network] +ManageForeignRoutes=no +ManageForeignRoutingPolicyRules=no +` +) + +// configureSystemdNetworkd creates a drop-in configuration file to prevent +// systemd-networkd from removing NetBird's routes and policy rules. +func configureSystemdNetworkd() error { + parentDir := filepath.Dir(networkdConfDir) + if _, err := os.Stat(parentDir); os.IsNotExist(err) { + log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration") + return nil + } + + // nolint:gosec // standard networkd permissions + if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { + return fmt.Errorf("write networkd configuration: %w", err) + } + + return nil +} + +// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file. 
+func cleanupSystemdNetworkd() error { + if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) { + return nil + } + + if err := os.Remove(networkdConfFile); err != nil { + return fmt.Errorf("remove networkd configuration: %w", err) + } + + return nil +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..24f12bc6d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,13 +15,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, 
mtu uint16) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..12dcaee8a 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) if !iface.IsUserspaceBind() { return fm, err @@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg if err != nil { log.Warnf("failed to create native firewall: %v. 
Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { - fm, err := createFW(iface) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { + fm, err := createFW(iface, mtu) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) } @@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - return nbiptables.Create(iface) + return nbiptables.Create(iface, mtu) case NFTABLES: log.Info("creating an nftables firewall manager") - return nbnftables.Create(iface) + return nbnftables.Create(iface, mtu) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") return nil, errors.New("no firewall manager found") } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) } if 
errUsp != nil { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 32103b7ec..2563a9052 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) @@ -260,7 +261,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -268,7 +269,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. 
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index a5cc62feb..6b5401e2b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go 
b/client/firewall/iptables/router_linux.go index 343f5e05e..305b0bf28 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -30,17 +30,20 @@ const ( chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" + chainFORWARD = "FORWARD" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" matchSet = "--match-set" @@ -48,6 +51,9 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" fwdSuffix = "_fwd" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) type ruleInfo struct { @@ -77,16 +83,18 @@ type router struct { ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + mtu uint16 stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } -func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + mtu: mtu, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -416,6 +425,7 @@ func (r *router) createContainers() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + 
{chainRTMSSCLAMP, tableMangle}, } { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) @@ -438,6 +448,10 @@ func (r *router) createContainers() error { return fmt.Errorf("add jump rules: %w", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + return nil } @@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + // Add jump rule from FORWARD chain in mangle table to our custom chain + jumpRule := []string{ + "-j", chainRTMSSCLAMP, + } + if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil { + return fmt.Errorf("add jump to MSS clamp chain: %w", err) + } + r.rules[jumpMSSClamp] = jumpRule + + ruleOut := []string{ + "-o", r.wgIface.Name(), + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mss), + } + if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil { + return fmt.Errorf("add outbound MSS clamp rule: %w", err) + } + r.rules["mss-clamp-out"] = ruleOut + + return nil +} + func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -558,7 +601,7 @@ func (r *router) addJumpRules() error { } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { + for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} { if rule, exists := r.rules[ruleKey]; exists { var table, chain string switch ruleKey { @@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error { case jumpNatPre: table = tableNat 
chain = chainPREROUTING + case jumpMSSClamp: + table = tableMangle + chain = chainFORWARD default: return fmt.Errorf("unknown jump rule: %s", ruleKey) } @@ -880,7 +926,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) @@ -913,7 +959,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..6707573be 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) 
require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { assert.NoError(t, manager.Reset(), "shouldn't return error") }() - // Now 5 rules: // 1. established rule forward in // 2. estbalished rule forward out // 3. jump rule to POST nat chain @@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 7. static return masquerade rule // 8. mangle prerouting mark rule // 9. mangle postrouting mark rule - require.Len(t, manager.rules, 9, "should have created rules map") + // 10. jump rule to MSS clamping chain + // 11. MSS clamping rule for outbound traffic + require.Len(t, manager.rules, 11, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { @@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, 
"Failed to create iptables client") - r, err := newRouter(iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 6ef159e01..c88774c1f 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -11,6 +12,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + ipt, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 7ee33118b..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. 
AddPeerFiltering( id []byte, ip net.IP, diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return 
m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index aa016e1c2..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = "netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -44,16 +53,16 @@ type Manager struct { } // Create nftables 
firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(workTable, wgIface) + m.router, err = newRouter(workTable, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) @@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - log.Debugf("chain INPUT not found. 
Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { return fmt.Errorf("list tables: %w", 
err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -376,7 +315,7 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -384,7 +323,7 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { m.mutex.Lock() defer m.mutex.Unlock() @@ -398,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: 
[]byte(allowNetbirdInputRuleID), - } - _ = m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index c7f05dcb7..adec802c8 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) { // This test verifies rule insertion order in nftables peer ACLs // We add accept rule first, then deny rule to test ordering behavior - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { 
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr := runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err, "failed to create manager") require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 0c091da96..6192c92aa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -16,6 +16,7 @@ import ( "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -32,12 +33,17 @@ const ( chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" chainNameForward = "FORWARD" + chainNameMangleForward = "netbird-mangle-forward" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) const refreshRulesMapError = "refresh rules map: %w" @@ -63,9 +69,10 @@ type router struct { wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool + mtu uint16 } -func newRouter(workTable *nftables.Table, wgIface iFaceMapper) 
(*router, error) { +func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ conn: &nftables.Conn{}, workTable: workTable, @@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) rules: make(map[string]*nftables.Rule), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), + mtu: mtu, } r.ipsetCounter = refcounter.New( @@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -220,11 +228,23 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) + r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameMangleForward, + Table: r.workTable, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + // Add the single NAT 
rule that matches on mark if err := r.addPostroutingRules(); err != nil { return fmt.Errorf("add single nat rule: %v", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + exprsOut := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 1, + Mask: []byte{0x02}, + Xor: []byte{0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00}, + }, + &expr.Counter{}, + &expr.Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + &expr.Cmp{ + Op: expr.CmpOpGt, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameMangleForward], + Exprs: exprsOut, + }) + + return nil +} + // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { 
sourceExp, err := r.applyNetwork(pair.Source, nil, true) @@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") @@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error { fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available @@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables() } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr
= multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +func (r *router) acceptFilterRulesNftables() error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ @@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error { }, } - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error { Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) + inputRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameInput, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) + return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) removeAcceptFilterRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if 
err != nil { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + return r.removeAcceptFilterRulesNftables() } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { +func (r *router) removeAcceptFilterRulesNftables() error { chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != r.filterTable.Name { + continue + } + + if chain.Name != chainNameForward && chain.Name != chainNameInput { continue } @@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error { for _, rule := range rules { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("delete rule: %v", err) } @@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error { return nil } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(merr, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
+ merr = multierror.Append(merr, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } @@ -1350,7 +1490,7 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) @@ -1426,7 +1566,7 @@ func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol return nil } -// RemoveInboundDNAT removes inbound DNAT rule +// RemoveInboundDNAT removes an inbound DNAT rule. func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { if err := r.refreshRulesMap(); err != nil { return fmt.Errorf(refreshRulesMapError, err) diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 4fdbf3505..3531b014b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -125,7 +126,7 @@ func
TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index f805623d6..48b7b3741 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,6 +3,7 @@ package nftables import ( "fmt" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -10,6 +11,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + nft, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ 
b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, 
size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state 
+func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. 
Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, 
dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + 
conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index e81042979..4e22bde3f 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "errors" "fmt" "net" @@ -27,7 +28,12 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const ( + layerTypeAll = 0 + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 +) // serviceKey represents a protocol/port combination for netstack service registry type serviceKey struct { @@ -42,6 +48,9 @@ const ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" + // EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic. + EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING" + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. 
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" @@ -116,14 +125,16 @@ type Manager struct { dnatMutex sync.RWMutex dnatBiMap *biDNATMap - // Port-specific DNAT for SSH redirection portDNATEnabled atomic.Bool - portDNATMap *portDNATMap + portDNATRules []portDNATRule portDNATMutex sync.RWMutex - portNATTracker *portNATTracker netstackServices map[serviceKey]struct{} netstackServiceMutex sync.RWMutex + + mtu uint16 + mssClampValue uint16 + mssClampEnabled bool } // decoder for packages @@ -137,19 +148,21 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - return create(iface, nil, disableServerRoutes, flowLogger) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger, mtu) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } @@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. 
return mgr, nil } -func parseCreateEnv() (bool, bool) { - var disableConntrack, enableLocalForwarding bool +func parseCreateEnv() (bool, bool, bool) { + var disableConntrack, enableLocalForwarding, disableMSSClamping bool var err error if val := os.Getenv(EnvDisableConntrack); val != "" { disableConntrack, err = strconv.ParseBool(val) @@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) { log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) } } + if val := os.Getenv(EnvDisableMSSClamping); val != "" { + disableMSSClamping, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err) + } + } - return disableConntrack, enableLocalForwarding + return disableConntrack, enableLocalForwarding, disableMSSClamping } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - disableConntrack, enableLocalForwarding := parseCreateEnv() +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -211,16 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), - portDNATMap: &portDNATMap{rules: make([]portDNATRule, 0)}, - portNATTracker: newPortNATTracker(), + portDNATRules: []portDNATRule{}, netstackServices: make(map[serviceKey]struct{}), + mtu: mtu, } m.routingEnabled.Store(false) + if !disableMSSClamping { + m.mssClampEnabled = true + m.mssClampValue = mtu - ipTCPHeaderMinSize + } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) } - if disableConntrack { log.Info("conntrack is 
disabled") } else { @@ -228,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } - - // netstack needs the forwarder for local traffic if m.netstack && m.localForwarding { if err := m.initForwarder(); err != nil { log.Errorf("failed to initialize forwarder: %v", err) } } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } @@ -338,7 +357,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu) if err != nil { m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) @@ -351,22 +370,18 @@ func (m *Manager) initForwarder() error { return nil } -// Init initializes the firewall manager with state manager. func (m *Manager) Init(*statemanager.Manager) error { return nil } -// IsServerRouteSupported returns whether server routes are supported. func (m *Manager) IsServerRouteSupported() bool { return true } -// IsStateful returns whether the firewall manager tracks connection state. func (m *Manager) IsStateful() bool { return m.stateful } -// AddNatRule adds a routing firewall rule for NAT translation. 
func (m *Manager) AddNatRule(pair firewall.RouterPair) error { if m.nativeRouter.Load() && m.nativeFirewall != nil { return m.nativeFirewall.AddNatRule(pair) @@ -648,13 +663,21 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return false } - if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { - return true + switch d.decoded[1] { + case layers.LayerTypeUDP: + if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { + return true + } + case layers.LayerTypeTCP: + // Clamp MSS on all TCP SYN packets, including those from local IPs. + // SNATed routed traffic may appear as local IP but still requires clamping. + if m.mssClampEnabled { + m.clampTCPMSS(packetData, d) + } } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) - m.translateOutboundPortReverse(packetData, d) return false } @@ -697,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation. +// Both sides advertise their MSS during connection establishment, so we need to clamp both. 
+func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { + if !d.tcp.SYN { + return false + } + if len(d.tcp.Options) == 0 { + return false + } + + mssOptionIndex := -1 + var currentMSS uint16 + for i, opt := range d.tcp.Options { + if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { + currentMSS = binary.BigEndian.Uint16(opt.OptionData) + if currentMSS > m.mssClampValue { + mssOptionIndex = i + break + } + } + } + + if mssOptionIndex == -1 { + return false + } + + ipHeaderSize := int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + + if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + return false + } + + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + return true +} + +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { + tcpHeaderStart := ipHeaderSize + tcpOptionsStart := tcpHeaderStart + 20 + + optOffset := tcpOptionsStart + for j := 0; j < mssOptionIndex; j++ { + switch d.tcp.Options[j].OptionType { + case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: + optOffset++ + default: + optOffset += 2 + len(d.tcp.Options[j].OptionData) + } + } + + mssValueOffset := optOffset + 2 + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + + m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) + return true +} + +func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) { + tcpLayer := packetData[tcpHeaderStart:] + tcpLength := len(packetData) - tcpHeaderStart + + tcpLayer[16] = 0 + tcpLayer[17] = 0 + + var pseudoSum uint32 + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += 
uint32(tcpLength) + + var sum uint32 = pseudoSum + for i := 0; i < tcpLength-1; i += 2 { + sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) + } + if tcpLength%2 == 1 { + sum += uint32(tcpLayer[tcpLength-1]) << 8 + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + checksum := ^uint16(sum) + binary.BigEndian.PutUint16(tcpLayer[16:18], checksum) +} + +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -714,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, 
uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -782,10 +910,11 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } - if translated := m.translateInboundPortDNAT(packetData, d); translated { + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { // Re-decode after port DNAT translation to update port information if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after port DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) @@ -794,7 +923,7 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..5a2d0410f 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -17,6 +17,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + 
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { 
require.NoError(b, manager.Close(nil)) }) @@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } -// generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { - b.Helper() - - ipv4 := &layers.IPv4{ - TTL: 64, - Version: 4, - SrcIP: srcIP, - DstIP: dstIP, - Protocol: layers.IPProtocolTCP, - } - - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - - // Set TCP flags - tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 - tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 - tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 - tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 - tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - - require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) - return buf.Bytes() -} - func BenchmarkRouteACLs(b *testing.B) { manager := setupRoutedManager(b, "10.10.0.100/16") @@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) { } } } + +// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound. 
+// This shows the overhead difference between the common case (non-SYN packets, fast path) +// and the rare case (SYN packets that need clamping, expensive path). +func BenchmarkMSSClamping(b *testing.B) { + scenarios := []struct { + name string + description string + genPacket func(*testing.B, net.IP, net.IP) []byte + frequency string + }{ + { + name: "syn_needs_clamp", + description: "SYN packet needing MSS clamping", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + frequency: "~0.1% of traffic - EXPENSIVE", + }, + { + name: "syn_no_clamp_needed", + description: "SYN packet with already-small MSS", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200) + }, + frequency: "~0.05% of traffic", + }, + { + name: "tcp_ack", + description: "Non-SYN TCP packet (ACK, data transfer)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + frequency: "~60-70% of traffic - FAST PATH", + }, + { + name: "tcp_psh_ack", + description: "TCP data packet (PSH+ACK)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + frequency: "~10-20% of traffic - FAST PATH", + }, + { + name: "udp", + description: "UDP packet", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + frequency: "~20-30% of traffic - FAST PATH", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled 
= true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled +// for the common case (non-SYN TCP packets). +func BenchmarkMSSClampingOverhead(b *testing.B) { + scenarios := []struct { + name string + enabled bool + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "disabled_tcp_ack", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "enabled_tcp_ack", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "disabled_syn_needs_clamp", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "enabled_syn_needs_clamp", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = sc.enabled + if sc.enabled { + manager.mssClampValue = 1240 + } + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + 
} + }) + } +} + +// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases +func BenchmarkMSSClampingMemory(b *testing.B) { + scenarios := []struct { + name string + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "tcp_ack_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "syn_needs_clamp", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "udp_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte { + b.Helper() + + ip := &layers.IPv4{ + Version: 4, + IHL: 5, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Seq: 1000, + Window: 65535, + } + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ip)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + require.NoError(b, 
gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{}))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/filter_filter_test.go b/client/firewall/uspfilter/filter_filter_test.go index 73f3face8..eb5aa3343 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -12,6 +12,7 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) require.NotNil(t, manager) @@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(tb, err) require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) @@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 3eebba59c..120a9f418 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -17,6 +18,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" 
"github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nbiface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" @@ -67,7 +69,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -87,7 +89,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -120,7 +122,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -216,7 +218,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -266,7 +268,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -305,7 +307,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err 
!= nil { t.Errorf("failed to create Manager: %v", err) return @@ -368,7 +370,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface, false, flowLogger) + manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -414,7 +416,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() @@ -496,7 +498,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) time.Sleep(time.Second) @@ -523,7 +525,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() // Close the existing tracker @@ -730,7 +732,7 @@ func TestUpdateSetMerge(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -816,7 +818,7 @@ func TestUpdateSetDeduplication(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, 
err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -925,6 +927,195 @@ func TestUpdateSetDeduplication(t *testing.T) { } } +func TestMSSClamping(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, 1280) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") + expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) + require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + srcIP := net.ParseIP("100.10.0.2") + dstIP := net.ParseIP("8.8.8.8") + + t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + }) + + t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { + lowMSS := uint16(1200) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified") + }) + + t.Run("SYN-ACK 
packet gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + }) + + t.Run("Non-SYN packet unchanged", func(t *testing.T) { + packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck)) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Empty(t, d.tcp.Options, "ACK packet should have no options") + }) +} + +func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: 
layers.TCPPort(dstPort), + SYN: true, + ACK: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Window: 65535, + } + + if flags&uint16(conntrack.TCPSyn) != 0 { + tcpLayer.SYN = true + } + if flags&uint16(conntrack.TCPAck) != 0 { + tcpLayer.ACK = true + } + if flags&uint16(conntrack.TCPFin) != 0 { + tcpLayer.FIN = true + } + if flags&uint16(conntrack.TCPRst) != 0 { + tcpLayer.RST = true + } + if flags&uint16(conntrack.TCPPush) != 0 { + tcpLayer.PSH = true + } + + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + func TestShouldForward(t *testing.T) { // Set up test addresses wgIP := netip.MustParseAddr("100.10.0.1") @@ -939,7 +1130,7 @@ func TestShouldForward(t *testing.T) { return wgaddr.Address{IP: wgIP, Network: netip.PrefixFrom(wgIP, 24)} } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, 
nbiface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 42a3e0800..00cb3f1df 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -45,7 +45,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { +func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow HandleLocal: false, }) - mtu, err := iface.GetDevice().MTU() - if err != nil { - return nil, fmt.Errorf("get MTU: %w", err) - } nicID := tcpip.NICID(1) endpoint := &endpoint{ logger: logger, @@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } if err := s.CreateNIC(nicID, endpoint); err != nil { - return nil, fmt.Errorf("failed to create NIC: %v", err) + return nil, fmt.Errorf("create NIC: %v", err) } protoAddr := tcpip.ProtocolAddress{ diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d146de5e4..55743d975 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -49,7 +49,7 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { +func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ 
logger: logger, diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index bf1c6feb5..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,8 +5,7 @@ import ( "errors" "fmt" "net/netip" - "sync" - "time" + "slices" "github.com/google/gopacket" "github.com/google/gopacket/layers" @@ -21,10 +20,16 @@ var ( ) const ( - errRewriteTCPDestinationPort = "rewrite TCP destination port: %v" + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 ) -// ipv4Checksum calculates IPv4 header checksum using optimized parallel processing for performance. +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -64,7 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } -// icmpChecksum calculates ICMP checksum using parallel accumulation for high-performance processing. +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -102,116 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } -// biDNATMap maintains bidirectional DNAT mappings for efficient forward and reverse lookups. +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } -// portDNATRule represents a port-specific DNAT rule +// portDNATRule represents a port-specific DNAT rule. type portDNATRule struct { protocol gopacket.LayerType - sourcePort uint16 + origPort uint16 targetPort uint16 targetIP netip.Addr } -// portDNATMap manages port-specific DNAT rules -type portDNATMap struct { - rules []portDNATRule -} - -// ConnKey represents a connection 4-tuple for NAT tracking. 
-type ConnKey struct { - SrcIP netip.Addr - DstIP netip.Addr - SrcPort uint16 - DstPort uint16 -} - -// portNATConn tracks port NAT state for a specific connection. -type portNATConn struct { - rule portDNATRule - originalPort uint16 - translatedAt time.Time -} - -// portNATTracker tracks connection-specific port NAT state -type portNATTracker struct { - connections map[ConnKey]*portNATConn - mutex sync.RWMutex -} - -// newPortNATTracker creates a new port NAT tracker for stateful connection tracking. -func newPortNATTracker() *portNATTracker { - return &portNATTracker{ - connections: make(map[ConnKey]*portNATConn), - } -} - -// trackConnection tracks a connection that has port NAT applied using translated port as key. -func (t *portNATTracker) trackConnection(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, rule portDNATRule) { - t.mutex.Lock() - defer t.mutex.Unlock() - - key := ConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - SrcPort: srcPort, - DstPort: rule.targetPort, - } - - t.connections[key] = &portNATConn{ - rule: rule, - originalPort: dstPort, - translatedAt: time.Now(), - } -} - -// getConnectionNAT returns NAT info for a connection if tracked, looking up by connection 4-tuple. -func (t *portNATTracker) getConnectionNAT(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) (*portNATConn, bool) { - t.mutex.RLock() - defer t.mutex.RUnlock() - - key := ConnKey{ - SrcIP: srcIP, - DstIP: dstIP, - SrcPort: srcPort, - DstPort: dstPort, - } - - conn, exists := t.connections[key] - return conn, exists -} - -// shouldApplyNAT checks if NAT should be applied to a new connection to prevent bidirectional conflicts. 
-func (t *portNATTracker) shouldApplyNAT(srcIP, dstIP netip.Addr, dstPort uint16) bool { - t.mutex.RLock() - defer t.mutex.RUnlock() - - for key, conn := range t.connections { - if key.SrcIP == dstIP && key.DstIP == srcIP && - conn.rule.sourcePort == dstPort && conn.originalPort == dstPort { - return false - } - } - return true -} - -// cleanupConnection removes a NAT connection based on original 4-tuple for connection cleanup. -func (t *portNATTracker) cleanupConnection(srcIP, dstIP netip.Addr, srcPort uint16) { - t.mutex.Lock() - defer t.mutex.Unlock() - - for key := range t.connections { - if key.SrcIP == srcIP && key.DstIP == dstIP && key.SrcPort == srcPort { - delete(t.connections, key) - return - } - } -} - -// newBiDNATMap creates a new bidirectional DNAT mapping structure for efficient forward/reverse lookups. +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -219,7 +129,7 @@ func newBiDNATMap() *biDNATMap { } } -// set adds a bidirectional DNAT mapping between original and translated addresses for both directions. +// set adds a bidirectional DNAT mapping between original and translated addresses. func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original @@ -233,13 +143,13 @@ func (b *biDNATMap) delete(original netip.Addr) { } } -// getTranslated returns the translated address for a given original address from forward mapping. +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } -// getOriginal returns the original address for a given translated address from reverse mapping. +// getOriginal returns the original address for a given translated address. 
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists @@ -261,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -295,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists with fast-path optimization. +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -307,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic using reverse mapping. +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -319,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets for 1:1 IP mapping. +// translateOutboundDNAT applies DNAT translation to outbound packets. 
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -336,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -345,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic for 1:1 IP mapping. +// translateInboundReverse applies reverse DNAT to inbound return traffic. 
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -362,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -371,17 +272,17 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet and updates checksums. -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. 
+func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { @@ -395,9 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -406,42 +307,7 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP return nil } -// rewritePacketSource replaces the source IP address in the packet and updates checksums. 
-func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return errInvalidIPHeaderLength - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// updateTCPChecksum updates TCP checksum after IP address change using incremental update per RFC 1624. +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -454,7 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } -// updateUDPChecksum updates UDP checksum after IP address change using incremental update per RFC 1624. +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. 
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -472,7 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } -// updateICMPChecksum recalculates ICMP checksum after packet modification using full recalculation. +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -485,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 for performance. +// incrementalUpdate performs incremental checksum update per RFC 1624. func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -536,25 +402,25 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { return m.nativeFirewall.DeleteDNATRule(rule) } -// addPortRedirection adds port redirection rule for specified target IP, protocol and ports. +// addPortRedirection adds a port redirection rule. func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() rule := portDNATRule{ protocol: protocol, - sourcePort: sourcePort, + origPort: sourcePort, targetPort: targetPort, targetIP: targetIP, } - m.portDNATMap.rules = append(m.portDNATMap.rules, rule) + m.portDNATRules = append(m.portDNATRules, rule) m.portDNATEnabled.Store(true) return nil } -// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services on specific ports. 
+// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -569,27 +435,23 @@ func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protoco return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) } -// removePortRedirection removes port redirection rule for specified target IP, protocol and ports. +// removePortRedirection removes a port redirection rule. func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { m.portDNATMutex.Lock() defer m.portDNATMutex.Unlock() - var filteredRules []portDNATRule - for _, rule := range m.portDNATMap.rules { - if !(rule.protocol == protocol && rule.sourcePort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0) { - filteredRules = append(filteredRules, rule) - } - } - m.portDNATMap.rules = filteredRules + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) - if len(m.portDNATMap.rules) == 0 { + if len(m.portDNATRules) == 0 { m.portDNATEnabled.Store(false) } return nil } -// RemoveInboundDNAT removes inbound DNAT rule for specified local address and ports. +// RemoveInboundDNAT removes an inbound DNAT rule. 
func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { var layerType gopacket.LayerType switch protocol { @@ -604,146 +466,55 @@ func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Prot return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) } -// translateInboundPortDNAT applies stateful port-specific DNAT translation to inbound packets. -func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder) bool { +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. +func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { if !m.portDNATEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: return false } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return false - } - - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - srcPort := uint16(d.tcp.SrcPort) - dstPort := uint16(d.tcp.DstPort) - - if m.handleReturnTraffic(packetData, d, srcIP, dstIP, srcPort, dstPort) { - return true - } - - return m.handleNewConnection(packetData, d, srcIP, dstIP, srcPort, dstPort) } -// handleReturnTraffic processes return traffic for existing NAT connections. 
-func (m *Manager) handleReturnTraffic(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if m.isTranslatedPortTraffic(srcIP, srcPort) { - return false - } +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error - if handled := m.handleExistingNATConnection(packetData, d, srcIP, dstIP, srcPort, dstPort); handled { - return true - } - - return m.handleForwardTrafficInExistingConnections(packetData, d, srcIP, dstIP, srcPort, dstPort) -} - -// isTranslatedPortTraffic checks if traffic is from a translated port that should be handled by outbound reverse. -func (m *Manager) isTranslatedPortTraffic(srcIP netip.Addr, srcPort uint16) bool { +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { m.portDNATMutex.RLock() defer m.portDNATMutex.RUnlock() - for _, rule := range m.portDNATMap.rules { - if rule.protocol == layers.LayerTypeTCP && rule.targetPort == srcPort && - rule.targetIP.Unmap().Compare(srcIP.Unmap()) == 0 { - return true + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue } - } - return false -} -// handleExistingNATConnection processes return traffic for existing NAT connections. 
-func (m *Manager) handleExistingNATConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists { - if err := m.rewriteTCPDestinationPort(packetData, d, natConn.originalPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { return false } - m.logger.Trace4("Inbound Port DNAT (return): %s:%d -> %s:%d", dstIP, srcPort, dstIP, natConn.originalPort) - return true - } - return false -} -// handleForwardTrafficInExistingConnections processes forward traffic in existing connections. -func (m *Manager) handleForwardTrafficInExistingConnections(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - m.portDNATMutex.RLock() - defer m.portDNATMutex.RUnlock() - - for _, rule := range m.portDNATMap.rules { - if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort { - continue - } - if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 { + if rule.origPort != port { continue } - if _, exists := m.portNATTracker.getConnectionNAT(srcIP, dstIP, srcPort, rule.targetPort); !exists { - continue - } - - if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) return false } + d.dnatOrigPort = rule.origPort return true } - return false } -// handleNewConnection processes new connections that match port DNAT rules. 
-func (m *Manager) handleNewConnection(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - m.portDNATMutex.RLock() - defer m.portDNATMutex.RUnlock() - - for _, rule := range m.portDNATMap.rules { - if m.applyPortDNATRule(packetData, d, rule, srcIP, dstIP, srcPort, dstPort) { - return true - } - } - return false -} - -// applyPortDNATRule applies a specific port DNAT rule if conditions are met. -func (m *Manager) applyPortDNATRule(packetData []byte, d *decoder, rule portDNATRule, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) bool { - if rule.protocol != layers.LayerTypeTCP || rule.sourcePort != dstPort { - return false - } - - if rule.targetIP.Unmap().Compare(dstIP.Unmap()) != 0 { - return false - } - - if !m.portNATTracker.shouldApplyNAT(srcIP, dstIP, dstPort) { - return false - } - - if err := m.rewriteTCPDestinationPort(packetData, d, rule.targetPort); err != nil { - m.logger.Error1(errRewriteTCPDestinationPort, err) - return false - } - - m.portNATTracker.trackConnection(srcIP, dstIP, srcPort, dstPort, rule) - m.logger.Trace8("Inbound Port DNAT (new): %s:%d -> %s:%d (tracked: %s:%d -> %s:%d)", dstIP, rule.sourcePort, dstIP, rule.targetPort, srcIP, srcPort, dstIP, rule.targetPort) - return true -} - -// rewriteTCPDestinationPort rewrites the destination port in a TCP packet and updates checksum. -func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPort uint16) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return ErrIPv4Only - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return fmt.Errorf("not a TCP packet") - } - +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { return errInvalidIPHeaderLength @@ -754,9 +525,9 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo return fmt.Errorf("packet too short for TCP header") } - oldPort := binary.BigEndian.Uint16(packetData[tcpStart+2 : tcpStart+4]) - - binary.BigEndian.PutUint16(packetData[tcpStart+2:tcpStart+4], newPort) + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) if len(packetData) >= tcpStart+18 { checksumOffset := tcpStart + 16 @@ -773,75 +544,34 @@ func (m *Manager) rewriteTCPDestinationPort(packetData []byte, d *decoder, newPo return nil } -// rewriteTCPSourcePort rewrites the source port in a TCP packet and updates checksum. -func (m *Manager) rewriteTCPSourcePort(packetData []byte, d *decoder, newPort uint16) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return ErrIPv4Only - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return fmt.Errorf("not a TCP packet") - } - +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { return errInvalidIPHeaderLength } - tcpStart := ipHeaderLen - if len(packetData) < tcpStart+4 { - return fmt.Errorf("packet too short for TCP header") + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") } - oldPort := binary.BigEndian.Uint16(packetData[tcpStart : tcpStart+2]) + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) - binary.BigEndian.PutUint16(packetData[tcpStart:tcpStart+2], newPort) - - if len(packetData) >= tcpStart+18 { - checksumOffset := tcpStart + 16 + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) - var oldPortBytes, newPortBytes [2]byte - binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) - binary.BigEndian.PutUint16(newPortBytes[:], newPort) - - newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) - binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } } return nil } - -// translateOutboundPortReverse applies stateful reverse port DNAT to outbound return traffic for SSH redirection. 
-func (m *Manager) translateOutboundPortReverse(packetData []byte, d *decoder) bool { - if !m.portDNATEnabled.Load() { - return false - } - - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - - if len(d.decoded) < 2 || d.decoded[1] != layers.LayerTypeTCP { - return false - } - - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) - srcPort := uint16(d.tcp.SrcPort) - dstPort := uint16(d.tcp.DstPort) - - // For outbound reverse, we need to find the connection using the same key as when it was stored - // Connection was stored as: srcIP, dstIP, srcPort, translatedPort - // So for return traffic (srcIP=server, dstIP=client), we need: dstIP, srcIP, dstPort, srcPort - if natConn, exists := m.portNATTracker.getConnectionNAT(dstIP, srcIP, dstPort, srcPort); exists { - if err := m.rewriteTCPSourcePort(packetData, d, natConn.rule.sourcePort); err != nil { - m.logger.Error1("rewrite TCP source port: %v", err) - return false - } - - return true - } - - return false -} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d2599e577 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) { func BenchmarkDNATConcurrency(b 
*testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation 
(5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + 
manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_stateful_test.go b/client/firewall/uspfilter/nat_stateful_test.go index 5c7853397..21c6da06e 100644 --- a/client/firewall/uspfilter/nat_stateful_test.go +++ b/client/firewall/uspfilter/nat_stateful_test.go @@ -7,15 +7,15 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) -// TestStatefulNATBidirectionalSSH tests that stateful NAT prevents interference -// when two peers try to SSH to each other simultaneously -func TestStatefulNATBidirectionalSSH(t *testing.T) { +// TestPortDNATBasic tests basic port DNAT functionality +func TestPortDNATBasic(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -30,49 +30,28 @@ func TestStatefulNATBidirectionalSSH(t *testing.T) { require.NoError(t, err) // Scenario: Peer A connects to Peer B on port 22 (should get NAT) - // This simulates: ssh user@100.10.0.51 packetAtoB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) - translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, parsePacket(t, packetAtoB)) + d := parsePacket(t, packetAtoB) + translatedAtoB := manager.translateInboundPortDNAT(packetAtoB, d, peerA, peerB) require.True(t, translatedAtoB, "Peer A to Peer B should be 
translated (NAT applied)") // Verify port was translated to 22022 - d := parsePacket(t, packetAtoB) + d = parsePacket(t, packetAtoB) require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be rewritten to 22022") - // Verify NAT connection is tracked (with translated port as key) - natConn, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022) - require.True(t, exists, "NAT connection should be tracked") - require.Equal(t, uint16(22), natConn.originalPort, "Original port should be stored") - - // Scenario: Peer B tries to connect to Peer A on port 22 (should NOT get NAT) - // This simulates the reverse direction to prevent interference - packetBtoA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) - translatedBtoA := manager.translateInboundPortDNAT(packetBtoA, parsePacket(t, packetBtoA)) - require.False(t, translatedBtoA, "Peer B to Peer A should NOT be translated (prevent interference)") - - // Verify port was NOT translated - d2 := parsePacket(t, packetBtoA) - require.Equal(t, uint16(22), uint16(d2.tcp.DstPort), "Port should remain 22 (no translation)") - - // Verify no reverse NAT connection is tracked - _, reverseExists := manager.portNATTracker.getConnectionNAT(peerB, peerA, 54322, 22) - require.False(t, reverseExists, "Reverse NAT connection should NOT be tracked") - - // Scenario: Return traffic from Peer B (SSH server) to Peer A (should be reverse translated) + // Scenario: Return traffic from Peer B to Peer A should NOT be translated + // (prevents double NAT - original port stored in conntrack) returnPacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 22022, 54321) - translatedReturn := manager.translateOutboundPortReverse(returnPacket, parsePacket(t, returnPacket)) - require.True(t, translatedReturn, "Return traffic should be reverse translated") - - // Verify return traffic port was translated back to 22 - d3 := parsePacket(t, returnPacket) - require.Equal(t, uint16(22), 
uint16(d3.tcp.SrcPort), "Return traffic source port should be 22") + d2 := parsePacket(t, returnPacket) + translatedReturn := manager.translateInboundPortDNAT(returnPacket, d2, peerB, peerA) + require.False(t, translatedReturn, "Return traffic from same IP should not be translated") } -// TestStatefulNATConnectionCleanup tests connection cleanup functionality -func TestStatefulNATConnectionCleanup(t *testing.T) { +// TestPortDNATMultipleRules tests multiple port DNAT rules +func TestPortDNATMultipleRules(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -88,24 +67,19 @@ func TestStatefulNATConnectionCleanup(t *testing.T) { err = manager.addPortRedirection(peerB, layers.LayerTypeTCP, 22, 22022) require.NoError(t, err) - // Establish connection with NAT - packet := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) - translated := manager.translateInboundPortDNAT(packet, parsePacket(t, packet)) - require.True(t, translated, "Initial connection should be translated") + // Test traffic to peer B gets translated + packetToB := generateDNATTestPacket(t, peerA, peerB, layers.IPProtocolTCP, 54321, 22) + d1 := parsePacket(t, packetToB) + translatedToB := manager.translateInboundPortDNAT(packetToB, d1, peerA, peerB) + require.True(t, translatedToB, "Traffic to peer B should be translated") + d1 = parsePacket(t, packetToB) + require.Equal(t, uint16(22022), uint16(d1.tcp.DstPort), "Port should be 22022") - // Verify connection is tracked (using translated port as key) - _, exists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022) - require.True(t, exists, "Connection should be tracked") - - // Clean up connection - manager.portNATTracker.cleanupConnection(peerA, peerB, 54321) - - // Verify connection is no longer tracked (using 
translated port as key) - _, stillExists := manager.portNATTracker.getConnectionNAT(peerA, peerB, 54321, 22022) - require.False(t, stillExists, "Connection should be cleaned up") - - // Verify new connection from opposite direction now works - reversePacket := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) - reverseTranslated := manager.translateInboundPortDNAT(reversePacket, parsePacket(t, reversePacket)) - require.True(t, reverseTranslated, "Reverse connection should now work after cleanup") + // Test traffic to peer A gets translated + packetToA := generateDNATTestPacket(t, peerB, peerA, layers.IPProtocolTCP, 54322, 22) + d2 := parsePacket(t, packetToA) + translatedToA := manager.translateInboundPortDNAT(packetToA, d2, peerB, peerA) + require.True(t, translatedToA, "Traffic to peer A should be translated") + d2 = parsePacket(t, packetToA) + require.Equal(t, uint16(22022), uint16(d2.tcp.DstPort), "Port should be 22022") } diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 4c43077bc..400d61020 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -1,17 +1,15 @@ package uspfilter import ( - "io" - "net" "net/netip" "testing" - "time" "github.com/google/gopacket" "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -19,7 +17,7 @@ import ( func TestDNATTranslationCorrectness(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -103,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { func TestDNATMappingManagement(t *testing.T) { manager, err := 
Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -148,521 +146,110 @@ func TestDNATMappingManagement(t *testing.T) { require.Error(t, err, "Should error when removing non-existent mapping") } -// TestSSHPortRedirection tests SSH port redirection functionality -func TestSSHPortRedirection(t *testing.T) { +func TestInboundPortDNAT(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) }() - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } - // Verify port DNAT is enabled - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) - // Verify the rule configuration - rule := manager.portDNATMap.rules[0] - require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol) - require.Equal(t, uint16(22), rule.sourcePort) - require.Equal(t, uint16(22022), rule.targetPort) 
- require.Equal(t, peerIP, rule.targetIP) + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) - // Test inbound SSH packet (client -> peer:22, should redirect to peer:22022) - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - originalInbound := make([]byte, len(inboundPacket)) - copy(originalInbound, inboundPacket) + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") - // Process inbound packet - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound SSH packet should be translated") - - // Verify destination port was changed from 22 to 22022 - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Destination port should be rewritten to 22022") - - // Verify destination IP remains unchanged - dstIPAfter := netip.AddrFrom4([4]byte{inboundPacket[16], inboundPacket[17], inboundPacket[18], inboundPacket[19]}) - require.Equal(t, peerIP, dstIPAfter, "Destination IP should remain unchanged") - - // Test outbound return packet (peer:22022 -> client, should rewrite source port to 22) - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321) - originalOutbound := make([]byte, len(outboundPacket)) - copy(originalOutbound, outboundPacket) - - // Process outbound return packet - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - // Verify source port was changed from 22022 to 22 - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Source port should be rewritten to 22") - - // Verify source IP remains unchanged - 
srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]}) - require.Equal(t, peerIP, srcIPAfter, "Source IP should remain unchanged") - - // Test removal of SSH port redirection - err = manager.RemoveInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal") - require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal") -} - -// TestSSHPortRedirectionNetworkFiltering tests that SSH redirection only applies to specified networks -func TestSSHPortRedirectionNetworkFiltering(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerInNetwork := netip.MustParseAddr("100.10.0.50") - peerOutsideNetwork := netip.MustParseAddr("192.168.1.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule for NetBird network only - err = manager.AddInboundDNAT(peerInNetwork, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test SSH packet to peer within NetBird network (should be redirected) - inNetworkPacket := generateDNATTestPacket(t, clientIP, peerInNetwork, layers.IPProtocolTCP, 54321, 22) - translated := manager.translateInboundPortDNAT(inNetworkPacket, parsePacket(t, inNetworkPacket)) - require.True(t, translated, "SSH packet to NetBird peer should be translated") - - // Verify port was changed - d := parsePacket(t, inNetworkPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected for NetBird peer") - - // Test SSH packet to peer outside NetBird network (should NOT be redirected) - outOfNetworkPacket := generateDNATTestPacket(t, clientIP, peerOutsideNetwork, 
layers.IPProtocolTCP, 54321, 22) - originalOutOfNetwork := make([]byte, len(outOfNetworkPacket)) - copy(originalOutOfNetwork, outOfNetworkPacket) - - notTranslated := manager.translateInboundPortDNAT(outOfNetworkPacket, parsePacket(t, outOfNetworkPacket)) - require.False(t, notTranslated, "SSH packet to non-NetBird peer should NOT be translated") - - // Verify packet was not modified - require.Equal(t, originalOutOfNetwork, outOfNetworkPacket, "Packet to non-NetBird peer should remain unchanged") -} - -// TestSSHPortRedirectionNonTCPTraffic tests that only TCP traffic is affected -func TestSSHPortRedirectionNonTCPTraffic(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test UDP packet on port 22 (should NOT be redirected) - udpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolUDP, 54321, 22) - originalUDP := make([]byte, len(udpPacket)) - copy(originalUDP, udpPacket) - - translated := manager.translateInboundPortDNAT(udpPacket, parsePacket(t, udpPacket)) - require.False(t, translated, "UDP packet should NOT be translated by SSH port redirection") - require.Equal(t, originalUDP, udpPacket, "UDP packet should remain unchanged") - - // Test ICMP packet (should NOT be redirected) - icmpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolICMPv4, 0, 0) - originalICMP := make([]byte, len(icmpPacket)) - copy(originalICMP, icmpPacket) - - translated = manager.translateInboundPortDNAT(icmpPacket, parsePacket(t, icmpPacket)) - require.False(t, translated, "ICMP packet should NOT be 
translated by SSH port redirection") - require.Equal(t, originalICMP, icmpPacket, "ICMP packet should remain unchanged") -} - -// TestSSHPortRedirectionNonSSHPorts tests that only port 22 is redirected -func TestSSHPortRedirectionNonSSHPorts(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define NetBird network range - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") - - // Add SSH port redirection rule - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Test TCP packet on port 80 (should NOT be redirected) - httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - originalHTTP := make([]byte, len(httpPacket)) - copy(originalHTTP, httpPacket) - - translated := manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket)) - require.False(t, translated, "Non-SSH TCP packet should NOT be translated") - require.Equal(t, originalHTTP, httpPacket, "Non-SSH TCP packet should remain unchanged") - - // Test TCP packet on port 443 (should NOT be redirected) - httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - originalHTTPS := make([]byte, len(httpsPacket)) - copy(originalHTTPS, httpsPacket) - - translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket)) - require.False(t, translated, "Non-SSH TCP packet should NOT be translated") - require.Equal(t, originalHTTPS, httpsPacket, "Non-SSH TCP packet should remain unchanged") - - // Test TCP packet on port 22 (SHOULD be redirected) - sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated = manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, 
translated, "SSH TCP packet should be translated") - - // Verify port was changed to 22022 - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH port should be redirected to 22022") -} - -// TestFlexiblePortRedirection tests the flexible port redirection functionality -func TestFlexiblePortRedirection(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define peer and client IPs - peerIP := netip.MustParseAddr("10.0.0.50") - clientIP := netip.MustParseAddr("10.0.0.100") - - // Add custom port redirection: TCP port 8080 -> 3000 for peer IP - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000) - require.NoError(t, err) - - // Verify port DNAT is enabled - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 1, "Should have one port DNAT rule") - - // Verify the rule configuration - rule := manager.portDNATMap.rules[0] - require.Equal(t, gopacket.LayerType(layers.LayerTypeTCP), rule.protocol) - require.Equal(t, uint16(8080), rule.sourcePort) - require.Equal(t, uint16(3000), rule.targetPort) - require.Equal(t, peerIP, rule.targetIP) - - // Test inbound packet (client -> peer:8080, should redirect to peer:3000) - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 8080) - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound packet should be translated") - - // Verify destination port was changed from 8080 to 3000 - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(3000), uint16(d.tcp.DstPort), "Destination port should be rewritten to 3000") - - // Test outbound return packet (peer:3000 -> client, should rewrite 
source port to 8080) - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 3000, 54321) - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - // Verify source port was changed from 3000 to 8080 - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(8080), uint16(d.tcp.SrcPort), "Source port should be rewritten to 8080") - - // Test removal of port redirection - err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 8080, 3000) - require.NoError(t, err) - require.False(t, manager.portDNATEnabled.Load(), "Port DNAT should be disabled after removal") - require.Len(t, manager.portDNATMap.rules, 0, "Should have no port DNAT rules after removal") -} - -// TestMultiplePortRedirections tests multiple port redirection rules -func TestMultiplePortRedirections(t *testing.T) { - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Define peer and client IPs - peerIP := netip.MustParseAddr("172.16.0.50") - clientIP := netip.MustParseAddr("172.16.0.100") - - // Add multiple port redirections for peer IP - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 22, 22022) // SSH - require.NoError(t, err) - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) // HTTP - require.NoError(t, err) - err = manager.addPortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 443, 8443) // HTTPS - require.NoError(t, err) - - // Verify all rules are present - require.True(t, manager.portDNATEnabled.Load(), "Port DNAT should be enabled") - require.Len(t, manager.portDNATMap.rules, 3, "Should have three port DNAT rules") - - // Test SSH redirection (22 -> 
22022) - sshPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, translated, "SSH packet should be translated") - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "SSH should redirect to 22022") - - // Test HTTP redirection (80 -> 8080) - httpPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - translated = manager.translateInboundPortDNAT(httpPacket, parsePacket(t, httpPacket)) - require.True(t, translated, "HTTP packet should be translated") - d = parsePacket(t, httpPacket) - require.Equal(t, uint16(8080), uint16(d.tcp.DstPort), "HTTP should redirect to 8080") - - // Test HTTPS redirection (443 -> 8443) - httpsPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - translated = manager.translateInboundPortDNAT(httpsPacket, parsePacket(t, httpsPacket)) - require.True(t, translated, "HTTPS packet should be translated") - d = parsePacket(t, httpsPacket) - require.Equal(t, uint16(8443), uint16(d.tcp.DstPort), "HTTPS should redirect to 8443") - - // Test removing one rule (HTTP) - err = manager.removePortRedirection(peerIP, gopacket.LayerType(layers.LayerTypeTCP), 80, 8080) - require.NoError(t, err) - require.Len(t, manager.portDNATMap.rules, 2, "Should have two rules after removing HTTP rule") - - // Verify HTTP is no longer redirected - httpPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 80) - originalHTTP := make([]byte, len(httpPacket2)) - copy(originalHTTP, httpPacket2) - translated = manager.translateInboundPortDNAT(httpPacket2, parsePacket(t, httpPacket2)) - require.False(t, translated, "HTTP packet should NOT be translated after rule removal") - require.Equal(t, originalHTTP, httpPacket2, "HTTP packet should remain unchanged") - - // Verify SSH and HTTPS still work - sshPacket2 := 
generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) - translated = manager.translateInboundPortDNAT(sshPacket2, parsePacket(t, sshPacket2)) - require.True(t, translated, "SSH should still be translated") - - httpsPacket2 := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 443) - translated = manager.translateInboundPortDNAT(httpsPacket2, parsePacket(t, httpsPacket2)) - require.True(t, translated, "HTTPS should still be translated") -} - -// TestSSHPortRedirectionEndToEnd tests actual network delivery through sockets -func TestSSHPortRedirectionEndToEnd(t *testing.T) { - // Start a mock SSH server on port 22022 (NetBird SSH server) - mockSSHServer, err := net.Listen("tcp", "127.0.0.1:22022") - require.NoError(t, err, "Should be able to bind to NetBird SSH port") - defer func() { - require.NoError(t, mockSSHServer.Close()) - }() - - // Handle connections on the SSH server - serverReceivedData := make(chan string, 1) - go func() { - for { - conn, err := mockSSHServer.Accept() - if err != nil { - return // Server closed + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) } - go func(conn net.Conn) { - defer func() { - require.NoError(t, conn.Close()) - }() - buf := make([]byte, 1024) - n, err := conn.Read(buf) - if err != nil && err != io.EOF { - t.Logf("Server read error: %v", err) - return - } + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") - receivedData := string(buf[:n]) - serverReceivedData <- receivedData - - // Echo back a response - _, err = conn.Write([]byte("SSH-2.0-MockNetBirdSSH\r\n")) - if err != nil { - t.Logf("Server write error: %v", err) - } - }(conn) - } - }() - - // Give server time to start - time.Sleep(100 * time.Millisecond) - - // This test demonstrates what SHOULD happen after port redirection: - 
// 1. Client connects to 127.0.0.1:22 (standard SSH port) - // 2. Firewall redirects to 127.0.0.1:22022 (NetBird SSH server) - // 3. NetBird SSH server receives the connection - - t.Run("DirectConnectionToNetBirdSSHPort", func(t *testing.T) { - // This simulates what should happen AFTER port redirection - // Connect directly to 22022 (where NetBird SSH server listens) - conn, err := net.DialTimeout("tcp", "127.0.0.1:22022", 5*time.Second) - require.NoError(t, err, "Should connect to NetBird SSH server") - defer func() { - require.NoError(t, conn.Close()) - }() - - // Send SSH client identification - testData := "SSH-2.0-TestClient\r\n" - _, err = conn.Write([]byte(testData)) - require.NoError(t, err, "Should send data to SSH server") - - // Verify server received the data - select { - case received := <-serverReceivedData: - require.Equal(t, testData, received, "Server should receive client data") - case <-time.After(2 * time.Second): - t.Fatal("Server did not receive data within timeout") - } - - // Read server response - buf := make([]byte, 1024) - if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil { - t.Logf("failed to set read deadline: %v", err) - } - n, err := conn.Read(buf) - require.NoError(t, err, "Should read server response") - - response := string(buf[:n]) - require.Equal(t, "SSH-2.0-MockNetBirdSSH\r\n", response, "Should receive SSH server identification") - }) - - t.Run("PortRedirectionSimulation", func(t *testing.T) { - // This test simulates the port redirection process - // Note: This doesn't test the actual userspace packet interception, - // but demonstrates the expected behavior - - t.Log("NOTE: This test demonstrates expected behavior after implementing") - t.Log("full userspace packet interception. Currently, we test packet") - t.Log("translation logic separately from actual network delivery.") - - // In a real implementation with userspace packet interception: - // 1. Client would connect to 127.0.0.1:22 - // 2. 
Userspace firewall would intercept packets - // 3. translateInboundPortDNAT would rewrite port 22 -> 22022 - // 4. Packets would be delivered to 127.0.0.1:22022 - // 5. NetBird SSH server would receive the connection - - // For now, we verify that the packet translation logic works correctly - // (this is tested in other test functions) and that the target server - // is reachable (tested above) - - clientIP := netip.MustParseAddr("127.0.0.1") - serverIP := netip.MustParseAddr("127.0.0.1") - - // Create manager with SSH port redirection - manager, err := Create(&IFaceMock{ - SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) - require.NoError(t, err) - defer func() { - require.NoError(t, manager.Close(nil)) - }() - - // Add SSH port redirection for localhost (for testing) - err = manager.AddInboundDNAT(netip.MustParseAddr("127.0.0.1"), firewall.ProtocolTCP, 22, 22022) - require.NoError(t, err) - - // Generate packet: client connecting to server:22 - sshPacket := generateDNATTestPacket(t, clientIP, serverIP, layers.IPProtocolTCP, 54321, 22) - originalPacket := make([]byte, len(sshPacket)) - copy(originalPacket, sshPacket) - - // Apply port redirection - translated := manager.translateInboundPortDNAT(sshPacket, parsePacket(t, sshPacket)) - require.True(t, translated, "SSH packet should be translated") - - // Verify port was redirected from 22 to 22022 - d := parsePacket(t, sshPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Port should be redirected to NetBird SSH server") - require.NotEqual(t, originalPacket, sshPacket, "Packet should be modified") - - t.Log("✓ Packet translation verified: port 22 redirected to 22022") - t.Log("✓ Target SSH server (port 22022) is reachable and responsive") - t.Log("→ Integration complete: SSH port redirection ready for userspace interception") - }) + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) 
+ }) + } } -// TestFullSSHRedirectionWorkflow demonstrates the complete SSH redirection workflow -func TestFullSSHRedirectionWorkflow(t *testing.T) { - t.Log("=== SSH Port Redirection Workflow Test ===") - t.Log("This test demonstrates the complete SSH redirection process:") - t.Log("1. Client connects to peer:22 (standard SSH)") - t.Log("2. Userspace firewall intercepts and redirects to peer:22022") - t.Log("3. NetBird SSH server receives connection on port 22022") - t.Log("4. Return traffic is reverse-translated (22022 -> 22)") - - // Setup test environment +func TestInboundPortDNATNegative(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) }() - // Define NetBird network and peer IPs - peerIP := netip.MustParseAddr("100.10.0.50") - clientIP := netip.MustParseAddr("100.10.0.100") + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") - // Step 1: Configure SSH port redirection - err = manager.AddInboundDNAT(peerIP, firewall.ProtocolTCP, 22, 22022) + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) require.NoError(t, err) - t.Log("✓ SSH port redirection configured for NetBird network") - // Step 2: Simulate inbound SSH connection (client -> peer:22) - t.Log("→ Simulating: ssh user@100.10.0.50") - inboundPacket := generateDNATTestPacket(t, clientIP, peerIP, layers.IPProtocolTCP, 54321, 22) + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", 
layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } - // Step 3: Apply inbound port redirection - translated := manager.translateInboundPortDNAT(inboundPacket, parsePacket(t, inboundPacket)) - require.True(t, translated, "Inbound SSH packet should be redirected") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) - d := parsePacket(t, inboundPacket) - require.Equal(t, uint16(22022), uint16(d.tcp.DstPort), "Should redirect to NetBird SSH server port") - t.Log("✓ Inbound packet redirected: 100.10.0.50:22 → 100.10.0.50:22022") + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) - // Step 4: Simulate outbound return traffic (peer:22022 -> client) - t.Log("→ Simulating return traffic from NetBird SSH server") - outboundPacket := generateDNATTestPacket(t, peerIP, clientIP, layers.IPProtocolTCP, 22022, 54321) - - // Step 5: Apply outbound reverse translation - reversed := manager.translateOutboundPortReverse(outboundPacket, parsePacket(t, outboundPacket)) - require.True(t, reversed, "Outbound return packet should be reverse translated") - - d = parsePacket(t, outboundPacket) - require.Equal(t, uint16(22), uint16(d.tcp.SrcPort), "Should restore original SSH port") - t.Log("✓ Outbound packet reverse translated: 100.10.0.50:22022 → 100.10.0.50:22") - - // Step 6: Verify client sees standard SSH connection - srcIPAfter := netip.AddrFrom4([4]byte{outboundPacket[12], outboundPacket[13], outboundPacket[14], outboundPacket[15]}) - require.Equal(t, peerIP, srcIPAfter, "Client should see traffic from peer IP") - t.Log("✓ Client receives traffic from 100.10.0.50:22 (transparent redirection)") - - t.Log("=== SSH Port Redirection Workflow Complete ===") - t.Log("Result: Standard SSH clients can connect to NetBird peers using:") - 
t.Log(" ssh user@100.10.0.50") - t.Log("Instead of:") - t.Log(" ssh user@100.10.0.50 -p 22022") + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: "Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m 
*Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + 
trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := 
m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) 
+ + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff --git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..d9f9f1aa8 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -10,6 +10,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" 
"github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) if !statefulMode { @@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -246,6 +259,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + 
StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" "runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // 
Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. +func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). 
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index c276f445d..4bc0fd800 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" 
"github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" @@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) diff --git a/client/internal/connect.go b/client/internal/connect.go index 743307fd0..6ad5f264b 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + engine := c.engine + c.engine = nil + c.engineMutex.Unlock() + + if engine != nil && engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + 
engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index d39910cb4..c3d006a2b 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -47,7 +47,7 @@ nftables.txt: Anonymized nftables rules with packet counters, if --system-info f resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. -state.json: Anonymized client state dump containing netbird states. +state.json: Anonymized client state dump containing netbird states for the active profile. mutex.prof: Mutex profiling information. goroutine.prof: Goroutine profiling information. block.prof: Block profiling information. 
@@ -576,6 +576,8 @@ func (g *BundleGenerator) addStateFile() error { return nil } + log.Debugf("Adding state file from: %s", path) + data, err := os.ReadFile(path) if err != nil { if errors.Is(err, fs.ErrNotExist) { diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) 
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", 
err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state 
:= sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index a14a01f40..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -17,6 +17,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/client/internal/winregistry" ) var ( @@ -178,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -197,6 +192,10 @@ func (r *registryConfigurator) applyDNSConfig(config 
HostDNSConfig, stateManager matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, ".")) } + if err := r.removeDNSMatchPolicies(); err != nil { + log.Errorf("cleanup old dns match policies: %s", err) + } + if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) if err != nil { @@ -204,19 +203,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager } r.nrptEntryCount = count } else { - if err := r.removeDNSMatchPolicies(); err != nil { - return fmt.Errorf("remove dns match policies: %w", err) - } r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -227,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) @@ -273,9 +273,9 @@ func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []s return fmt.Errorf("remove existing dns policy: %w", err) } - regKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) + regKey, _, err := winregistry.CreateVolatileKey(registry.LOCAL_MACHINE, policyPath, registry.SET_VALUE) if err != nil { - return fmt.Errorf("create 
registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) + return fmt.Errorf("create volatile registry key HKEY_LOCAL_MACHINE\\%s: %w", policyPath, err) } defer closer(regKey) diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go new file mode 100644 index 000000000..19496bf5a --- /dev/null +++ b/client/internal/dns/host_windows_test.go @@ -0,0 +1,102 @@ +package dns + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/registry" +) + +// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up +// when the number of match domains decreases between configuration changes. +func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + config5 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + {Domain: "domain3.com", MatchOnly: true}, + {Domain: "domain4.com", MatchOnly: true}, + {Domain: "domain5.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config5, nil) + require.NoError(t, err) + + // Verify all 5 entries exist + for i 
:= 0; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after first config", i) + } + + config2 := HostDNSConfig{ + ServerIP: testIP, + Domains: []DomainConfig{ + {Domain: "domain1.com", MatchOnly: true}, + {Domain: "domain2.com", MatchOnly: true}, + }, + } + + err = cfg.applyDNSConfig(config2, nil) + require.NoError(t, err) + + // Verify first 2 entries exist + for i := 0; i < 2; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "Entry %d should exist after second config", i) + } + + // Verify entries 2-4 are cleaned up + for i := 2; i < 5; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) + } +} + +func registryKeyExists(path string) (bool, error) { + k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE) + if err != nil { + if err == registry.ErrNotExist { + return false, nil + } + return false, err + } + k.Close() + return true, nil +} + +func cleanupRegistryKeys(*testing.T) { + cfg := ®istryConfigurator{nrptEntryCount: 10} + _ = cfg.removeDNSMatchPolicies() +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..afaf0579f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. 
// This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..451b83f92 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false, flowLogger) + pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c..aef16a8cf 100644 --- 
a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux 
mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. 
bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1..4d0b96a75 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := 
NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. 
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,27 +4,34 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors 
"github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex +const ( + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) -const ( - dnsTTL = 60 //seconds -) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { @@ -36,28 +43,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + wgIface wgIface + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw 
firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + wgIface: wgIface, + serverPort: serverPort, } } @@ -71,7 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } + } + + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +123,20 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } 
+ + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +152,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { @@ -120,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + 
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index ed47c0cee..1deb3d3cf 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -202,11 +202,12 @@ type Engine struct { flowManager nftypes.FlowManager // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + wgIfaceMonitor *WGIfaceMonitor - // dns forwarder port - dnsFwdPort uint16 + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup + + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -248,7 +249,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -307,17 +308,12 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } + e.stopDNSForwarder() + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } @@ -334,10 +330,6 @@ func (e *Engine) Stop() error { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before removing interface. 
- // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -345,8 +337,6 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -357,12 +347,52 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. 
// However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -492,14 +522,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -515,7 +545,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -899,7 +929,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. 
func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -1014,10 +1046,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1038,7 +1074,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1156,10 +1192,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { @@ 
-1316,7 +1358,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() @@ -1613,7 +1657,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. -func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1645,8 +1689,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1663,15 +1711,10 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. 
+func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1693,7 +1736,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1703,8 +1748,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } @@ -1789,64 +1834,50 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - - default: + if e.dnsForwardMgr == nil { + e.startDNSForwarder(fwdEntries) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := 
e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil + e.stopDNSForwarder() } - } -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.dnsForwardMgr = nil } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l 
*Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index f5eb2c739..b3482f54e 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -1,4 +1,4 @@ -//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || 
openbsd package networkmonitor @@ -6,21 +6,19 @@ import ( "context" "errors" "fmt" - "syscall" - "unsafe" log "github.com/sirupsen/logrus" - "golang.org/x/net/route" "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + fd, err := prepareFd() if err != nil { return fmt.Errorf("open routing socket: %v", err) } + defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { @@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } }() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) - if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) - continue - } - - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - - switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } - } - 
} - } - } -} - -func parseRouteMessage(buf []byte) (*systemops.Route, error) { - msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.RouteMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return systemops.MsgToRoute(msg) + return routeCheck(ctx, fd, nexthopv4, nexthopv6) } diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go new file mode 100644 index 000000000..c287236e8 --- /dev/null +++ b/client/internal/networkmonitor/check_change_common.go @@ -0,0 +1,92 @@ +//go:build dragonfly || freebsd || netbsd || openbsd || darwin + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func prepareFd() (int, error) { + return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) +} + +func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Warnf("Network monitor: failed to read from routing socket: %v", err) + } + continue + } + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Debugf("Network monitor: 
error parsing routing message: %v", err) + continue + } + + if route.Dst.Bits() != 0 { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + return nil + } + } + } + } + } +} + +func parseRouteMessage(buf []byte) (*systemops.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return systemops.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go new file mode 100644 index 000000000..ddc6e1736 --- /dev/null +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -0,0 +1,149 @@ +//go:build darwin && !ios + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// todo: refactor to not use static functions + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + fd, err := prepareFd() + if err != nil { + return fmt.Errorf("open routing socket: %v", err) + } + + defer func() { + if err := unix.Close(fd); err != nil { + if !errors.Is(err, unix.EBADF) { + log.Warnf("Network monitor: failed to close routing socket: %v", 
err) + } + } + }() + + routeChanged := make(chan struct{}) + go func() { + _ = routeCheck(ctx, fd, nexthopv4, nexthopv6) + close(routeChanged) + }() + + wakeUp := make(chan struct{}) + go func() { + wakeUpListen(ctx) + close(wakeUp) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-routeChanged: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("route change detected") + return nil + case <-wakeUp: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("wakeup detected") + return nil + } +} + +func wakeUpListen(ctx context.Context) { + log.Infof("start to watch for system wakeups") + var ( + initialHash uint32 + err error + ) + + // Keep retrying until initial sysctl succeeds or context is canceled + for { + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + default: + initialHash, err = readSleepTimeHash() + if err != nil { + log.Errorf("failed to detect initial sleep time: %v", err) + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + case <-time.After(3 * time.Second): + continue + } + } + log.Debugf("initial wakeup hash: %d", initialHash) + break + } + break + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info("context canceled, stopping wakeUpListen") + return + + case <-ticker.C: + newHash, err := readSleepTimeHash() + if err != nil { + log.Errorf("failed to read sleep time hash: %v", err) + continue + } + + if newHash == initialHash { + log.Tracef("no wakeup detected") + continue + } + + upOut, err := exec.Command("uptime").Output() + if err != nil { + log.Errorf("failed to run uptime command: %v", err) + upOut = []byte("unknown") + } + log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut) + return + } + } +} + +func readSleepTimeHash() (uint32, error) { + cmd := 
exec.Command("sysctl", "kern.sleeptime") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("failed to run sysctl: %w", err) + } + + h, err := hash(out) + if err != nil { + return 0, fmt.Errorf("failed to compute hash: %w", err) + } + + return h, nil +} + +func hash(data []byte) (uint32, error) { + hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher + if _, err := hasher.Write(data); err != nil { + return 0, err + } + return hasher.Sum32(), nil +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index accdd9c9d..6d019258d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { event := make(chan struct{}, 1) go nw.checkChanges(ctx, event, nexthop4, nexthop6) + log.Infof("start watching for network changes") // debounce changes timer := time.NewTimer(0) timer.Stop() diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..6f4f5ad4f 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,10 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + 
probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := 
time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe returns fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age < p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) 
+ } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func 
ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) + for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..348338dac 
100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -18,7 +19,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..26cf758d9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + 
nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -78,6 +81,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -101,12 +105,13 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) @@ -130,6 +135,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,9 +276,15 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m 
*DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -345,6 +357,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { @@ -474,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -516,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial, diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 
index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. +func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ 
b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go 
b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" "github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. 
+func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + + if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..61c8bbc79 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,8 +9,6 @@ import ( "github.com/hashicorp/go-multierror" 
"golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) @@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) return isSelected } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/internal/winregistry/volatile_windows.go b/client/internal/winregistry/volatile_windows.go new file mode 100644 index 000000000..a8e350fe7 --- /dev/null +++ b/client/internal/winregistry/volatile_windows.go @@ -0,0 +1,59 @@ +package winregistry + +import ( + "syscall" + "unsafe" + + "golang.org/x/sys/windows/registry" +) + +var ( + advapi = syscall.NewLazyDLL("advapi32.dll") + regCreateKeyExW = advapi.NewProc("RegCreateKeyExW") +) + +const ( + // Registry key options + regOptionNonVolatile = 0x0 // Key is preserved when system is rebooted + regOptionVolatile = 0x1 // Key is not preserved when system is rebooted + + // Registry disposition values + regCreatedNewKey = 0x1 + regOpenedExistingKey = 0x2 +) + +// CreateVolatileKey creates a volatile registry key named path under open key root. 
+// CreateVolatileKey returns the new key and a boolean flag that reports whether the key already existed. +// The access parameter specifies the access rights for the key to be created. +// +// Volatile keys are stored in memory and are automatically deleted when the system is shut down. +// This provides automatic cleanup without requiring manual registry maintenance. +func CreateVolatileKey(root registry.Key, path string, access uint32) (registry.Key, bool, error) { + pathPtr, err := syscall.UTF16PtrFromString(path) + if err != nil { + return 0, false, err + } + + var ( + handle syscall.Handle + disposition uint32 + ) + + ret, _, _ := regCreateKeyExW.Call( + uintptr(root), + uintptr(unsafe.Pointer(pathPtr)), + 0, // reserved + 0, // class + uintptr(regOptionVolatile), // options - volatile key + uintptr(access), // desired access + 0, // security attributes + uintptr(unsafe.Pointer(&handle)), + uintptr(unsafe.Pointer(&disposition)), + ) + + if ret != 0 { + return 0, false, syscall.Errno(ret) + } + + return registry.Key(handle), disposition == regOpenedExistingKey, nil +} diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. 
func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. +func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. 
conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. 
func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/client/server/server.go b/client/server/server.go index b3fdb2f1f..76f44cd13 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -1070,10 +1070,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1266,7 +1263,7 @@ func isUnixRunningDesktop() bool { return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != "" } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1277,7 +1274,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = 
multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) + } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index 7199ad036..7cd7412f0 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -14,7 +14,6 @@ import ( // handleCommand executes an SSH command with privilege validation func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) { - localUser := privilegeResult.User hasPty := winCh != nil commandType := "command" @@ -22,7 +21,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege commandType = "Pty command" } - logger.Infof("executing %s for %s from %s: %s", commandType, localUser.Username, session.RemoteAddr(), safeLogCommand(session.Command())) + logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command())) execCmd, err := s.createCommand(privilegeResult, session, hasPty) if err != nil { @@ -38,7 +37,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege logger.Debugf(errWriteSession, writeErr) } if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return } @@ -74,7 +73,7 @@ func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd if err != nil { logger.Errorf("create stdin pipe: %v", err) if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false } @@ -93,7 +92,7 @@ func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd logger.Errorf("command start failed: %v", err) // no user message for exec failure, just exit if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false } @@ -135,7 +134,7 @@ func (s *Server) 
waitForCommandCleanup(logger *log.Entry, session ssh.Session, e } if err := session.Exit(130); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false @@ -160,7 +159,7 @@ func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) { if err == nil { if err := session.Exit(0); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return } @@ -168,12 +167,12 @@ func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.E var exitError *exec.ExitError if errors.As(err, &exitError) { if err := session.Exit(exitError.ExitCode()); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } } else { logger.Debugf("non-exit error in command execution: %v", err) if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } } } diff --git a/client/ssh/server/command_execution_unix.go b/client/ssh/server/command_execution_unix.go index f78678743..2b9db863b 100644 --- a/client/ssh/server/command_execution_unix.go +++ b/client/ssh/server/command_execution_unix.go @@ -98,10 +98,6 @@ func (pm *ptyManager) File() *os.File { } func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { - cmd := session.Command() - localUser := privilegeResult.User - logger.Infof("executing Pty command for %s from %s: %s", localUser.Username, session.RemoteAddr(), safeLogCommand(cmd)) - execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session) if err != nil { logger.Errorf("Pty command creation failed: %v", err) @@ -110,16 +106,24 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResu logger.Debugf(errWriteSession, writeErr) } if err := session.Exit(1); err != nil { - 
logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false } + shell := execCmd.Path + cmd := session.Command() + if len(cmd) == 0 { + logger.Infof("starting interactive shell: %s", shell) + } else { + logger.Infof("executing command: %s", safeLogCommand(cmd)) + } + ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq) if err != nil { logger.Errorf("Pty start failed: %v", err) if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false } @@ -178,18 +182,22 @@ func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *pty go func() { if _, err := io.Copy(ptmx, session); err != nil { - logger.Debugf("Pty input copy error: %v", err) + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty input copy error: %v", err) + } } }() go func() { defer func() { - if err := session.Close(); err != nil { + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { logger.Debugf("session close error: %v", err) } }() if _, err := io.Copy(session, ptmx); err != nil { - logger.Debugf("Pty output copy error: %v", err) + if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) { + logger.Warnf("Pty output copy error: %v", err) + } } }() } @@ -229,7 +237,7 @@ func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Ses } if err := session.Exit(130); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } } @@ -243,7 +251,7 @@ func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Sessi // Normal completion logger.Debugf("Pty command completed successfully") if err := session.Exit(0); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } } diff --git a/client/ssh/server/command_execution_windows.go b/client/ssh/server/command_execution_windows.go index 7b7a6a492..c646c36b8 100644 --- a/client/ssh/server/command_execution_windows.go +++ 
b/client/ssh/server/command_execution_windows.go @@ -266,9 +266,14 @@ func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) [] } func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool { - localUser := privilegeResult.User cmd := session.Command() - logger.Infof("executing Pty command for %s from %s: %s", localUser.Username, session.RemoteAddr(), safeLogCommand(cmd)) + shell := getUserShell(privilegeResult.User.Uid) + + if len(cmd) == 0 { + logger.Infof("starting interactive shell: %s", shell) + } else { + logger.Infof("executing command: %s", safeLogCommand(cmd)) + } // Always use user switching on Windows - no direct execution s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd) @@ -317,7 +322,7 @@ func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Sessi logger.Debugf(errWriteSession, writeErr) } if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index 71a988041..80e2ae141 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "io" "net" "net/netip" "strings" @@ -80,10 +81,17 @@ func (e *UserNotFoundError) Unwrap() error { return e.Cause } +// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors +func logSessionExitError(logger *log.Entry, err error) { + if err != nil && !errors.Is(err, io.EOF) { + logger.Warnf(errExitSession, err) + } +} + // safeLogCommand returns a safe representation of the command for logging func safeLogCommand(cmd []string) string { if len(cmd) == 0 { - return "" + return "" } if len(cmd) == 1 { return cmd[0] @@ -520,7 +528,7 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { return nil } - 
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP) + log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr) return conn } @@ -623,19 +631,19 @@ func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, s.mu.RUnlock() if !allowLocal { - log.Debugf("direct-tcpip rejected: local port forwarding disabled") + log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port) _ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled") return } // Check privilege requirements for the destination port if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil { - log.Infof("direct-tcpip denied: %v", err) + log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err) _ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges") return } - log.Debugf("direct-tcpip request: %s:%d", payload.Host, payload.Port) + log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port) ssh.DirectTCPIPHandler(srv, conn, newChan, ctx) } diff --git a/client/ssh/server/session_handlers.go b/client/ssh/server/session_handlers.go index 07e4051a7..8803de50a 100644 --- a/client/ssh/server/session_handlers.go +++ b/client/ssh/server/session_handlers.go @@ -16,22 +16,19 @@ import ( // sessionHandler handles SSH sessions func (s *Server) sessionHandler(session ssh.Session) { sessionKey := s.registerSession(session) + logger := log.WithField("session", sessionKey) + logger.Infof("SSH session started") sessionStart := time.Now() - logger := log.WithField("session", sessionKey) defer s.unregisterSession(sessionKey, session) defer func() { - duration := time.Since(sessionStart) - if err := session.Close(); err != nil { - logger.Debugf("close session after %v: %v", duration, err) - return + duration := time.Since(sessionStart).Round(time.Millisecond) + if err := session.Close(); err != nil && !errors.Is(err, io.EOF) { + 
logger.Warnf("close session after %v: %v", duration, err) } - - logger.Debugf("session closed after %v", duration) + logger.Infof("SSH session closed after %v", duration) }() - logger.Infof("establishing SSH session for %s from %s", session.User(), session.RemoteAddr()) - privilegeResult, err := s.userPrivilegeCheck(session.User()) if err != nil { s.handlePrivError(logger, session, err) @@ -61,7 +58,7 @@ func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) { logger.Debugf(errWriteSession, err) } if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr()) } @@ -86,7 +83,6 @@ func (s *Server) registerSession(session ssh.Session) SessionKey { s.sessions[sessionKey] = session s.mu.Unlock() - log.WithField("session", sessionKey).Debugf("registered SSH session") return sessionKey } @@ -111,7 +107,6 @@ func (s *Server) unregisterSession(sessionKey SessionKey, _ ssh.Session) { } s.mu.Unlock() - log.WithField("session", sessionKey).Debugf("unregistered SSH session") } func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) { @@ -123,7 +118,7 @@ func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err err logger.Debugf(errWriteSession, writeErr) } if exitErr := session.Exit(1); exitErr != nil { - logger.Debugf(errExitSession, exitErr) + logSessionExitError(logger, exitErr) } } diff --git a/client/ssh/server/session_handlers_js.go b/client/ssh/server/session_handlers_js.go index bca97ded5..c35e4da0b 100644 --- a/client/ssh/server/session_handlers_js.go +++ b/client/ssh/server/session_handlers_js.go @@ -16,7 +16,7 @@ func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCh logger.Debugf(errWriteSession, err) } if err := session.Exit(1); err != nil { - logger.Debugf(errExitSession, err) + logSessionExitError(logger, err) } return false } diff 
--git a/client/status/status.go b/client/status/status.go index 5e4fcd8dc..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 8f7845c74..56a25ab19 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -307,6 +307,8 @@ type serviceClient struct { mExitNodeDeselectAll *systray.MenuItem logFile string wLoginURL fyne.Window + + connectCancel context.CancelFunc } type menuHandler struct { @@ -698,17 +700,15 @@ func (s *serviceClient) hasSSHChanges() bool { s.disableSSHAuth != s.sDisableSSHAuth.Checked } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active profile: %v", err) - return nil, err + 
return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -716,84 +716,71 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginResp, err := conn.Login(ctx, &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, }) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + return nil, fmt.Errorf("login to management: %w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := openURL(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, 
- }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -803,24 +790,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, 
&proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -956,6 +939,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, @@ -1534,7 +1518,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1554,7 +1538,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1562,7 +1546,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1575,7 +1559,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, &proto.UpRequest{}) if err != nil { label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) diff --git a/client/ui/debug.go b/client/ui/debug.go index 76afc7753..bf9839dda 100644 --- a/client/ui/debug.go +++ b/client/ui/debug.go @@ -18,6 +18,7 @@ import ( 
"github.com/skratchdot/open-golang/open" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/proto" nbstatus "github.com/netbirdio/netbird/client/status" uptypes "github.com/netbirdio/netbird/upload-server/types" @@ -426,6 +427,12 @@ func (s *serviceClient) collectDebugData( return "", err } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { + profName = activeProf.Name + } + postUpStatus, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("Failed to get post-up status: %v", err) @@ -433,7 +440,7 @@ func (s *serviceClient) collectDebugData( var postUpStatusOutput string if postUpStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(postUpStatus, params.anonymize, "", nil, nil, nil, "", profName) postUpStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPostUp := fmt.Sprintf("----- NetBird post-up - Timestamp: %s", time.Now().Format(time.RFC3339)) @@ -450,7 +457,7 @@ func (s *serviceClient) collectDebugData( var preDownStatusOutput string if preDownStatus != nil { - overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(preDownStatus, params.anonymize, "", nil, nil, nil, "", profName) preDownStatusOutput = nbstatus.ParseToFullDetailSummary(overview) } headerPreDown := fmt.Sprintf("----- NetBird pre-down - Timestamp: %s - Duration: %s", @@ -574,6 +581,12 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa return nil, fmt.Errorf("get client: %v", err) } + pm := profilemanager.NewProfileManager() + var profName string + if activeProf, err := pm.GetActiveProfile(); err == nil { 
+ profName = activeProf.Name + } + statusResp, err := conn.Status(s.ctx, &proto.StatusRequest{GetFullPeerStatus: true}) if err != nil { log.Warnf("failed to get status for debug bundle: %v", err) @@ -581,7 +594,7 @@ func (s *serviceClient) createDebugBundle(anonymize bool, systemInfo bool, uploa var statusOutput string if statusResp != nil { - overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", "") + overview := nbstatus.ConvertToStatusOutputOverview(statusResp, anonymize, "", nil, nil, nil, "", profName) statusOutput = nbstatus.ParseToFullDetailSummary(overview) } diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index 348299e67..e0b619411 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + 
log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + 
upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p *profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..cf089d4ed 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. 
+ ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" @@ -31,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server diff --git a/go.mod b/go.mod index 1a3e500e8..1daf73844 100644 --- a/go.mod +++ b/go.mod @@ -63,7 +63,7 @@ require ( github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -78,7 +78,7 @@ require ( github.com/pion/turn/v3 v3.0.1 github.com/pkg/sftp v1.13.9 github.com/prometheus/client_golang v1.22.0 - github.com/quic-go/quic-go v0.48.2 + github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 @@ -245,7 +245,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.4.0 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.28.0 // indirect diff --git a/go.sum b/go.sum index 1bc3c24ca..e5c3aa0ba 100644 --- a/go.sum +++ b/go.sum @@ -508,8 +508,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod 
h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f h1:XIpRDlpPz3zFUkpwaqDRHjwpQRsf2ZKHggoex1MTafs= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251010134843-7af36217ac1f/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -597,8 +597,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 
h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -756,8 +756,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= diff --git a/infrastructure_files/configure.sh b/infrastructure_files/configure.sh index e3fcbfdde..92252d0b3 100755 --- a/infrastructure_files/configure.sh +++ b/infrastructure_files/configure.sh @@ -185,12 +185,15 @@ if [[ "$NETBIRD_DISABLE_LETSENCRYPT" == "true" ]]; then echo "You are also free to remove any occurrences of the Letsencrypt-volume $LETSENCRYPT_VOLUMENAME" echo "" - export NETBIRD_SIGNAL_PROTOCOL="https" unset NETBIRD_LETSENCRYPT_DOMAIN unset NETBIRD_MGMT_API_CERT_FILE unset NETBIRD_MGMT_API_CERT_KEY_FILE fi +if [[ -n "$NETBIRD_MGMT_API_CERT_FILE" && -n "$NETBIRD_MGMT_API_CERT_KEY_FILE" ]]; then + export NETBIRD_SIGNAL_PROTOCOL="https" +fi + # Check if management identity provider is set if [ -n "$NETBIRD_MGMT_IDP" ]; then EXTRA_CONFIG={} diff --git a/infrastructure_files/docker-compose.yml.tmpl b/infrastructure_files/docker-compose.yml.tmpl index 
b24e853b4..2bc49d3e5 100644 --- a/infrastructure_files/docker-compose.yml.tmpl +++ b/infrastructure_files/docker-compose.yml.tmpl @@ -40,13 +40,21 @@ services: signal: <<: *default image: netbirdio/signal:$NETBIRD_SIGNAL_TAG + depends_on: + - dashboard volumes: - $SIGNAL_VOLUMENAME:/var/lib/netbird + - $LETSENCRYPT_VOLUMENAME:/etc/letsencrypt:ro ports: - $NETBIRD_SIGNAL_PORT:80 # # port and command for Let's Encrypt validation # - 443:443 # command: ["--letsencrypt-domain", "$NETBIRD_LETSENCRYPT_DOMAIN", "--log-file", "console"] + command: [ + "--cert-file", "$NETBIRD_MGMT_API_CERT_FILE", + "--cert-key", "$NETBIRD_MGMT_API_CERT_KEY_FILE", + "--log-file", "console" + ] # Relay relay: diff --git a/infrastructure_files/getting-started-with-zitadel.sh b/infrastructure_files/getting-started-with-zitadel.sh index bc326cd7e..09c5225ad 100644 --- a/infrastructure_files/getting-started-with-zitadel.sh +++ b/infrastructure_files/getting-started-with-zitadel.sh @@ -682,17 +682,6 @@ renderManagementJson() { "URI": "stun:$NETBIRD_DOMAIN:3478" } ], - "TURNConfig": { - "Turns": [ - { - "Proto": "udp", - "URI": "turn:$NETBIRD_DOMAIN:3478", - "Username": "$TURN_USER", - "Password": "$TURN_PASSWORD" - } - ], - "TimeBasedCredentials": false - }, "Relay": { "Addresses": ["$NETBIRD_RELAY_PROTO://$NETBIRD_DOMAIN:$NETBIRD_PORT"], "CredentialsTTL": "24h", diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index ab9893f27..1b61c081d 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) 
+ + return manager }) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 037d11d11..8251b951f 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -110,7 +110,7 @@ type Manager interface { GetIdpManager() idp.Manager UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) - GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) + GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b165938..ffecb6b8f 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strconv" + "time" log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" @@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e return nil, err } - if storeEngine == types.SqliteStoreEngine { - sqlDB.SetMaxOpenConns(1) - } else { - conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) - if err != nil { - conns = runtime.NumCPU() - } - sqlDB.SetMaxOpenConns(conns) + conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) + if err != nil { + conns = runtime.NumCPU() } + if storeEngine == types.SqliteStoreEngine { + conns = 1 + } + + sqlDB.SetMaxOpenConns(conns) + sqlDB.SetMaxIdleConns(conns) + 
sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(3 * time.Minute) + + log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) return db, nil } diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..e5166ce47 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -196,7 +196,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +211,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 7e65b8f92..6b2db1162 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, 
dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } 
result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 4b33495de..df89c616c 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -78,7 +78,7 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) @@ -86,7 +86,9 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } _, valid := validPeers[peer.ID] - 
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid)) + reason := invalidPeers[peer.ID] + + util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { @@ -147,16 +149,17 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) util.WriteError(ctx, fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] + reason := invalidPeers[peer.ID] - util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { @@ -240,22 +243,25 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peerToReturn, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } - h.setApprovalRequiredFlag(respBody, 
validPeersMap) + h.setApprovalRequiredFlag(respBody, validPeersMap, invalidPeersMap) util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) { +func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersMap map[string]struct{}, invalidPeersMap map[string]string) { for _, peer := range respBody { - _, ok := approvedPeersMap[peer.Id] + _, ok := validPeersMap[peer.Id] if !ok { peer.ApprovalRequired = true + + reason := invalidPeersMap[peer.Id] + peer.DisapprovalReason = &reason } } } @@ -304,7 +310,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -430,13 +436,13 @@ func peerToAccessiblePeer(peer *nbpeer.Peer, dnsDomain string) api.AccessiblePee } } -func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool) *api.Peer { +func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, approved bool, reason string) *api.Peer { osVersion := peer.Meta.OSVersion if osVersion == "" { osVersion = peer.Meta.Core } - return &api.Peer{ + apiPeer := &api.Peer{ CreatedAt: peer.CreatedAt, Id: peer.ID, Name: peer.Name, @@ -465,6 +471,12 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD InactivityExpirationEnabled: peer.InactivityExpirationEnabled, Ephemeral: peer.Ephemeral, } + + if !approved { + apiPeer.DisapprovalReason = &reason + } + + return apiPeer } func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsDomain string, accessiblePeersCount int) *api.PeerBatch { diff 
--git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 34db5f8dc..8a37a8759 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 251c04273..e9a1c8701 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -88,7 +88,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID return true, nil } -func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { var err error var groups []*types.Group var peers []*nbpeer.Peer @@ -96,20 +96,30 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") if err != nil { - return nil, err + return nil, nil, err } settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { - return nil, err + return nil, nil, err } - return am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, 
settings.Extra) + validPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, accountID, groups, peers, settings.Extra) + if err != nil { + return nil, nil, err + } + + invalidPeers, err := am.integratedPeerValidator.GetInvalidPeers(ctx, accountID, settings.Extra) + if err != nil { + return nil, nil, err + } + + return validPeers, invalidPeers, nil } type MockIntegratedValidator struct { @@ -136,6 +146,10 @@ func (a MockIntegratedValidator) GetValidatedPeers(_ context.Context, accountID return validatedPeers, nil } +func (a MockIntegratedValidator) GetInvalidPeers(_ context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) { + return make(map[string]string), nil +} + func (MockIntegratedValidator) PreparePeer(_ context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer { return peer } diff --git a/management/server/integrations/integrated_validator/interface.go b/management/server/integrations/integrated_validator/interface.go index be05c2527..26c338cb6 100644 --- a/management/server/integrations/integrated_validator/interface.go +++ b/management/server/integrations/integrated_validator/interface.go @@ -15,6 +15,7 @@ type IntegratedValidator interface { PreparePeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings, temporary bool) *nbpeer.Peer IsNotValidPeer(ctx context.Context, accountID string, peer *nbpeer.Peer, peersGroup []string, extraSettings *types.ExtraSettings) (bool, bool, error) GetValidatedPeers(ctx context.Context, accountID string, groups []*types.Group, peers []*nbpeer.Peer, extraSettings *types.ExtraSettings) (map[string]struct{}, error) + GetInvalidPeers(ctx context.Context, accountID string, extraSettings *types.ExtraSettings) (map[string]string, error) PeerDeleted(ctx context.Context, accountID, peerID string, extraSettings *types.ExtraSettings) error 
SetPeerInvalidationListener(fn func(accountID string, peerIDs []string)) Stop(ctx context.Context) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 35e16aad1..93fb84e43 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -189,17 +189,17 @@ func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID st panic("implement me") } -func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, error) { +func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { account, err := am.GetAccountFunc(ctx, accountID) if err != nil { - return nil, err + return nil, nil, err } approvedPeers := make(map[string]struct{}) for id := range account.Peers { approvedPeers[id] = struct{}{} } - return approvedPeers, nil + return approvedPeers, nil, nil } // GetGroup mock implementation of GetGroup from server.AccountManager interface diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5a134eb32..b8c03a4bd 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1161,7 +1161,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 
100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account "github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. 
+func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 382d026c8..4201b68f6 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -89,8 +89,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met } sql.SetMaxOpenConns(conns) + sql.SetMaxIdleConns(conns) + sql.SetConnMaxLifetime(time.Hour) + sql.SetConnMaxIdleTime(3 * time.Minute) - log.WithContext(ctx).Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) if skipMigration { log.WithContext(ctx).Infof("skipping migration") diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..50bdc6ab3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -301,7 +301,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = 
append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -1682,7 +1682,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. -func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1693,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + 
expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 +927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: 
net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/release_files/install.sh b/release_files/install.sh index 5d5349ec4..6a2c5f458 100755 --- a/release_files/install.sh +++ b/release_files/install.sh @@ -29,6 +29,8 @@ if [ -z ${NETBIRD_RELEASE+x} ]; then NETBIRD_RELEASE=latest fi +TAG_NAME="" + get_release() { local RELEASE=$1 if [ "$RELEASE" = "latest" ]; then @@ -38,17 +40,19 @@ get_release() { local TAG="tags/${RELEASE}" local URL="https://api.github.com/repos/${OWNER}/${REPO}/releases/${TAG}" fi + OUTPUT="" if [ -n "$GITHUB_TOKEN" ]; then - curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -H "Authorization: token ${GITHUB_TOKEN}" -s "${URL}") else - curl -s "${URL}" \ - | grep '"tag_name":' | sed -E 's/.*"([^"]+)".*/\1/' + OUTPUT=$(curl -s "${URL}") fi + TAG_NAME=$(echo ${OUTPUT} | grep -Eo '\"tag_name\":\s*\"v([0-9]+\.){2}[0-9]+"' | tail -n 1) + echo "${TAG_NAME}" | grep -oE 'v[0-9]+\.[0-9]+\.[0-9]+' } download_release_binary() { VERSION=$(get_release "$NETBIRD_RELEASE") + echo "Using the following tag name for binary installation: ${TAG_NAME}" BASE_URL="https://github.com/${OWNER}/${REPO}/releases/download" 
BINARY_BASE_NAME="${VERSION#v}_${OS_TYPE}_${ARCH}.tar.gz" diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 93578b1ae..4a5454002 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -463,6 +463,9 @@ components: description: (Cloud only) Indicates whether peer needs approval type: boolean example: true + disapproval_reason: + description: (Cloud only) Reason why the peer requires approval + type: string country_code: $ref: '#/components/schemas/CountryCode' city_name: diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index 3dbb32ef6..9611d26d6 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -1037,6 +1037,9 @@ type Peer struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. 
peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` @@ -1124,6 +1127,9 @@ type PeerBatch struct { // CreatedAt Peer creation date (UTC) CreatedAt time.Time `json:"created_at"` + // DisapprovalReason (Cloud only) Reason why the peer requires approval + DisapprovalReason *string `json:"disapproval_reason,omitempty"` + // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud DnsLabel string `json:"dns_label"` diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index 11107e5de..16737cf58 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -428,7 +428,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 96873dee7..bf8f8e327 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -94,7 +94,7 @@ var ( startPprof() - opts, certManager, err := getTLSConfigurations() + opts, certManager, tlsConfig, err := getTLSConfigurations() if err != nil { return err } @@ -132,7 +132,7 @@ var ( // Start the main server - always serve HTTP with WebSocket proxy support // If certManager is 
configured and signalPort == 443, it's already handled by startServerWithCertManager - if certManager == nil { + if tlsConfig == nil { // Without TLS, serve plain HTTP httpListener, err = net.Listen("tcp", fmt.Sprintf(":%d", signalPort)) if err != nil { @@ -140,9 +140,10 @@ var ( } log.Infof("running HTTP server with WebSocket proxy (no TLS): %s", httpListener.Addr().String()) serveHTTP(httpListener, grpcRootHandler) - } else if signalPort != 443 { - // With TLS but not on port 443, serve HTTPS - httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), certManager.TLSConfig()) + } else if certManager == nil || signalPort != 443 { + // Serve HTTPS if not already handled by startServerWithCertManager + // (custom certificates or Let's Encrypt with custom port) + httpListener, err = tls.Listen("tcp", fmt.Sprintf(":%d", signalPort), tlsConfig) if err != nil { return err } @@ -202,7 +203,7 @@ func startPprof() { }() } -func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { +func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, *tls.Config, error) { var ( err error certManager *autocert.Manager @@ -211,33 +212,33 @@ func getTLSConfigurations() ([]grpc.ServerOption, *autocert.Manager, error) { if signalLetsencryptDomain == "" && signalCertFile == "" && signalCertKey == "" { log.Infof("running without TLS") - return nil, nil, nil + return nil, nil, nil, nil } if signalLetsencryptDomain != "" { certManager, err = encryption.CreateCertManager(signalSSLDir, signalLetsencryptDomain) if err != nil { - return nil, certManager, err + return nil, certManager, nil, err } tlsConfig = certManager.TLSConfig() log.Infof("setting up TLS with LetsEncrypt.") } else { if signalCertFile == "" || signalCertKey == "" { log.Errorf("both cert-file and cert-key must be provided when not using LetsEncrypt") - return nil, certManager, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") + return nil, certManager, 
nil, errors.New("both cert-file and cert-key must be provided when not using LetsEncrypt") } tlsConfig, err = loadTLSConfig(signalCertFile, signalCertKey) if err != nil { log.Errorf("cannot load TLS credentials: %v", err) - return nil, certManager, err + return nil, certManager, nil, err } log.Infof("setting up TLS with custom certificates.") } transportCredentials := credentials.NewTLS(tlsConfig) - return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, err + return []grpc.ServerOption{grpc.Creds(transportCredentials)}, certManager, tlsConfig, err } func startServerWithCertManager(certManager *autocert.Manager, grpcRootHandler http.Handler) {