diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..16df24ba8 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error { } func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") if err != nil { return nil, err } diff --git a/client/cmd/debug.go b/client/cmd/debug.go index d53c5f06b..430012a17 100644 --- a/client/cmd/debug.go +++ b/client/cmd/debug.go @@ -168,7 +168,7 @@ func runForDuration(cmd *cobra.Command, args []string) error { client := proto.NewDaemonServiceClient(conn) - stat, err := client.Status(cmd.Context(), &proto.StatusRequest{}) + stat, err := client.Status(cmd.Context(), &proto.StatusRequest{ShouldRunProbes: true}) if err != nil { return fmt.Errorf("failed to get status: %v", status.Convert(err).Message()) } @@ -303,7 +303,7 @@ func setSyncResponsePersistence(cmd *cobra.Command, args []string) error { func getStatusOutput(cmd *cobra.Command, anon bool) string { var statusOutputString string - statusResp, err := getStatus(cmd.Context()) + statusResp, err := getStatus(cmd.Context(), true) if err != nil { cmd.PrintErrf("Failed to get status: %v\n", err) } else { diff --git a/client/cmd/login.go b/client/cmd/login.go index 40b55f858..b0c877faa 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str Username: &username, } + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { loginRequest.OptionalPreSharedKey = &preSharedKey } @@ -241,7 +248,7 @@ func 
doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, return fmt.Errorf("read config file %s: %v", configFilePath, err) } - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo return nil } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { needsLogin := false err := WithBackOff(func() error { @@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken := "" if setupKey == "" && needsLogin { - tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) + tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { + hint := "" + pm := profilemanager.NewProfileManager() + profileState, err := pm.GetProfileState(profileName) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + hint = profileState.Email + } + + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) if err != nil { return nil, err } diff --git 
a/client/cmd/service_installer.go b/client/cmd/service_installer.go index 075ead44e..f6828d96a 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -10,6 +10,8 @@ import ( "path/filepath" "runtime" + log "github.com/sirupsen/logrus" + "github.com/kardianos/service" "github.com/spf13/cobra" @@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error { svcConfig.Option["LogDirectory"] = dir } } + + if err := configureSystemdNetworkd(); err != nil { + log.Warnf("failed to configure systemd-networkd: %v", err) + } } if runtime.GOOS == "windows" { @@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{ return fmt.Errorf("uninstall service: %w", err) } + if runtime.GOOS == "linux" { + if err := cleanupSystemdNetworkd(); err != nil { + log.Warnf("failed to cleanup systemd-networkd configuration: %v", err) + } + } + cmd.Println("NetBird service has been uninstalled") return nil }, @@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) { return status == service.StatusRunning, nil } + +const ( + networkdConf = "/etc/systemd/networkd.conf" + networkdConfDir = "/etc/systemd/networkd.conf.d" + networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf" + networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing +# routes and policy rules managed by NetBird. + +[Network] +ManageForeignRoutes=no +ManageForeignRoutingPolicyRules=no +` +) + +// configureSystemdNetworkd creates a drop-in configuration file to prevent +// systemd-networkd from removing NetBird's routes and policy rules. 
+func configureSystemdNetworkd() error { + if _, err := os.Stat(networkdConf); os.IsNotExist(err) { + log.Debug("systemd-networkd not in use, skipping configuration") + return nil + } + + // nolint:gosec // standard networkd permissions + if err := os.MkdirAll(networkdConfDir, 0755); err != nil { + return fmt.Errorf("create networkd.conf.d directory: %w", err) + } + + // nolint:gosec // standard networkd permissions + if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil { + return fmt.Errorf("write networkd configuration: %w", err) + } + + return nil +} + +// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file. +func cleanupSystemdNetworkd() error { + if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) { + return nil + } + + if err := os.Remove(networkdConfFile); err != nil { + return fmt.Errorf("remove networkd configuration: %w", err) + } + + return nil +} diff --git a/client/cmd/status.go b/client/cmd/status.go index 723f2367c..6e57ceb89 100644 --- a/client/cmd/status.go +++ b/client/cmd/status.go @@ -68,7 +68,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { ctx := internal.CtxInitState(cmd.Context()) - resp, err := getStatus(ctx) + resp, err := getStatus(ctx, false) if err != nil { return err } @@ -121,7 +121,7 @@ func statusFunc(cmd *cobra.Command, args []string) error { return nil } -func getStatus(ctx context.Context) (*proto.StatusResponse, error) { +func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) { conn, err := DialClientGRPCServer(ctx, daemonAddr) if err != nil { return nil, fmt.Errorf("failed to connect to daemon error: %v\n"+ @@ -130,7 +130,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) { } defer conn.Close() - resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true}) + resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, 
&proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes}) if err != nil { return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message()) } diff --git a/client/cmd/up.go b/client/cmd/up.go index d047c041e..80175f7be 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) - err = foregroundLogin(ctx, cmd, config, providedSetupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ loginRequest.ProfileName = &activeProf.Name loginRequest.Username = &username + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/firewall/create.go b/client/firewall/create.go index 7b265e1d1..24f12bc6d 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -15,13 +15,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) 
if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index aa2f0d4d1..12dcaee8a 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu) if !iface.IsUserspaceBind() { return fm, err @@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg if err != nil { log.Warnf("failed to create native firewall: %v. 
Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger) + return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { - fm, err := createFW(iface) +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) { + fm, err := createFW(iface, mtu) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) } @@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, return fm, nil } -func createFW(iface IFaceMapper) (firewall.Manager, error) { +func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) { switch check() { case IPTABLES: log.Info("creating an iptables firewall manager") - return nbiptables.Create(iface) + return nbiptables.Create(iface, mtu) case NFTABLES: log.Info("creating an nftables firewall manager") - return nbnftables.Create(iface) + return nbnftables.Create(iface, mtu) default: log.Info("no firewall manager found, trying to use userspace packet filtering firewall") return nil, errors.New("no firewall manager found") } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu) } else { - fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu) } if 
errUsp != nil { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 81f7a9125..2563a9052 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -36,7 +36,7 @@ type iFaceMapper interface { } // Create iptables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, fmt.Errorf("init iptables: %w", err) @@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) { ipv4Client: iptablesClient, } - m.router, err = newRouter(iptablesClient, wgIface) + m.router, err = newRouter(iptablesClient, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, } stateManager.RegisterState(state) @@ -260,6 +261,22 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. 
+func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/manager_linux_test.go b/client/firewall/iptables/manager_linux_test.go index a5cc62feb..6b5401e2b 100644 --- a/client/firewall/iptables/manager_linux_test.go +++ b/client/firewall/iptables/manager_linux_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/require" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) { require.NoError(t, err) // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) { ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) { } // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, 
err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second) diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index 081991235..305b0bf28 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -30,17 +30,20 @@ const ( chainPOSTROUTING = "POSTROUTING" chainPREROUTING = "PREROUTING" + chainFORWARD = "FORWARD" chainRTNAT = "NETBIRD-RT-NAT" chainRTFWDIN = "NETBIRD-RT-FWD-IN" chainRTFWDOUT = "NETBIRD-RT-FWD-OUT" chainRTPRE = "NETBIRD-RT-PRE" chainRTRDR = "NETBIRD-RT-RDR" + chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP" routingFinalForwardJump = "ACCEPT" routingFinalNatJump = "MASQUERADE" jumpManglePre = "jump-mangle-pre" jumpNatPre = "jump-nat-pre" jumpNatPost = "jump-nat-post" + jumpMSSClamp = "jump-mss-clamp" markManglePre = "mark-mangle-pre" markManglePost = "mark-mangle-post" matchSet = "--match-set" @@ -48,6 +51,9 @@ const ( dnatSuffix = "_dnat" snatSuffix = "_snat" fwdSuffix = "_fwd" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) type ruleInfo struct { @@ -77,16 +83,18 @@ type router struct { ipsetCounter *ipsetCounter wgIface iFaceMapper legacyManagement bool + mtu uint16 stateManager *statemanager.Manager ipFwdState *ipfwdstate.IPForwardingState } -func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) { +func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ iptablesClient: iptablesClient, rules: make(map[string][]string), wgIface: wgIface, + mtu: mtu, ipFwdState: ipfwdstate.NewIPForwardingState(), } @@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { ok, err := 
r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain) if err != nil { @@ -416,6 +425,7 @@ func (r *router) createContainers() error { {chainRTPRE, tableMangle}, {chainRTNAT, tableNat}, {chainRTRDR, tableNat}, + {chainRTMSSCLAMP, tableMangle}, } { if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil { return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err) @@ -438,6 +448,10 @@ func (r *router) createContainers() error { return fmt.Errorf("add jump rules: %w", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + return nil } @@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + // Add jump rule from FORWARD chain in mangle table to our custom chain + jumpRule := []string{ + "-j", chainRTMSSCLAMP, + } + if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil { + return fmt.Errorf("add jump to MSS clamp chain: %w", err) + } + r.rules[jumpMSSClamp] = jumpRule + + ruleOut := []string{ + "-o", r.wgIface.Name(), + "-p", "tcp", + "--tcp-flags", "SYN,RST", "SYN", + "-j", "TCPMSS", + "--set-mss", fmt.Sprintf("%d", mss), + } + if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil { + return fmt.Errorf("add outbound MSS clamp rule: %w", err) + } + r.rules["mss-clamp-out"] = ruleOut + + return nil +} + func (r *router) insertEstablishedRule(chain string) error { establishedRule := getConntrackEstablished() @@ -558,7 +601,7 @@ func (r *router) addJumpRules() error { } func (r *router) cleanJumpRules() error { - for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} { + for _, ruleKey := range []string{jumpNatPost, 
jumpManglePre, jumpNatPre, jumpMSSClamp} { if rule, exists := r.rules[ruleKey]; exists { var table, chain string switch ruleKey { @@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error { case jumpNatPre: table = tableNat chain = chainPREROUTING + case jumpMSSClamp: + table = tableMangle + chain = chainFORWARD default: return fmt.Errorf("unknown jump rule: %s", ruleKey) } @@ -880,6 +926,54 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nberrors.FormatErrorOrNil(merr) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + dnatRule := []string{ + "-i", r.wgIface.Name(), + "-p", strings.ToLower(string(protocol)), + "--dport", strconv.Itoa(int(sourcePort)), + "-d", localAddr.String(), + "-m", "addrtype", "--dst-type", "LOCAL", + "-j", "DNAT", + "--to-destination", ":" + strconv.Itoa(int(targetPort)), + } + + ruleInfo := ruleInfo{ + table: tableNat, + chain: chainRTRDR, + rule: dnatRule, + } + + if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + r.rules[ruleID] = ruleInfo.rule + + r.updateState() + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. 
+func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if dnatRule, exists := r.rules[ruleID]; exists { + if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil { + return fmt.Errorf("delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + r.updateState() + return nil +} + func applyPort(flag string, port *firewall.Port) []string { if port == nil { return nil diff --git a/client/firewall/iptables/router_linux_test.go b/client/firewall/iptables/router_linux_test.go index 3490c5dad..6707573be 100644 --- a/client/firewall/iptables/router_linux_test.go +++ b/client/firewall/iptables/router_linux_test.go @@ -14,6 +14,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" nbnet "github.com/netbirdio/netbird/client/net" ) @@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "should return a valid iptables manager") require.NoError(t, manager.init(nil)) @@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { assert.NoError(t, manager.Reset(), "shouldn't return error") }() - // Now 5 rules: // 1. established rule forward in // 2. estbalished rule forward out // 3. jump rule to POST nat chain @@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) { // 7. static return masquerade rule // 8. mangle prerouting mark rule // 9. 
mangle postrouting mark rule - require.Len(t, manager.rules, 9, "should have created rules map") + // 10. jump rule to MSS clamping chain + // 11. MSS clamping rule for outbound traffic + require.Len(t, manager.rules, 11, "should have created rules map") exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT) require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING) @@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "failed to init iptables client") - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) @@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) { t.Run(testCase.Name, func(t *testing.T) { iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4) - manager, err := newRouter(iptablesClient, ifaceMock) + manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "shouldn't return error") require.NoError(t, manager.init(nil)) defer func() { @@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) require.NoError(t, err, "Failed to create iptables client") - r, err := newRouter(iptablesClient, ifaceMock) + r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router manager") require.NoError(t, r.init(nil)) diff --git a/client/firewall/iptables/state_linux.go b/client/firewall/iptables/state_linux.go index 6ef159e01..c88774c1f 100644 --- a/client/firewall/iptables/state_linux.go +++ b/client/firewall/iptables/state_linux.go @@ -4,6 +4,7 @@ import ( "fmt" "sync" + 
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -11,6 +12,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - ipt, err := Create(s.InterfaceState) + mtu := s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + ipt, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create iptables manager: %w", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index 3b3164823..72e6a5c68 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -100,6 +100,9 @@ type Manager interface { // // If comment argument is empty firewall manager should set // rule ID as comment for the rule + // + // Note: Callers should call Flush() after adding rules to ensure + // they are applied to the kernel and rule handles are refreshed. AddPeerFiltering( id []byte, ip net.IP, @@ -151,14 +154,20 @@ type Manager interface { DisableRouting() error - // AddDNATRule adds a DNAT rule + // AddDNATRule adds outbound DNAT rule for forwarding external traffic to the NetBird network. AddDNATRule(ForwardRule) (Rule, error) - // DeleteDNATRule deletes a DNAT rule + // DeleteDNATRule deletes the outbound DNAT rule. 
DeleteDNATRule(Rule) error // UpdateSet updates the set with the given prefixes UpdateSet(hash Set, prefixes []netip.Prefix) error + + // AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services + AddInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error + + // RemoveInboundDNAT removes inbound DNAT rule + RemoveInboundDNAT(localAddr netip.Addr, protocol Protocol, sourcePort, targetPort uint16) error } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/acl_linux.go b/client/firewall/nftables/acl_linux.go index 9ff5b8c92..a9d066e2f 100644 --- a/client/firewall/nftables/acl_linux.go +++ b/client/firewall/nftables/acl_linux.go @@ -29,8 +29,6 @@ const ( chainNameForwardFilter = "netbird-acl-forward-filter" chainNameManglePrerouting = "netbird-mangle-prerouting" chainNameManglePostrouting = "netbird-mangle-postrouting" - - allowNetbirdInputRuleID = "allow Netbird incoming traffic" ) const flushError = "flush: %w" @@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error { // createDefaultAllowRules creates default allow rules for the input and output chains func (m *AclManager) createDefaultAllowRules() error { expIn := []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: 12, - Len: 4, - }, - // mask - &expr.Bitwise{ - SourceRegister: 1, - DestRegister: 1, - Len: 4, - Mask: []byte{0, 0, 0, 0}, - Xor: []byte{0, 0, 0, 0}, - }, - // net address - &expr.Cmp{ - Register: 1, - Data: []byte{0, 0, 0, 0}, - }, &expr.Verdict{ Kind: expr.VerdictAccept, }, @@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering( action firewall.Action, ipset *nftables.Set, ) (*Rule, error) { - ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset) + ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset) if r, ok := m.rules[ruleId]; ok { return &Rule{ nftRule: r.nftRule, @@ -357,11 +336,12 @@ func 
(m *AclManager) addIOFiltering( } if err := m.rConn.Flush(); err != nil { - return nil, fmt.Errorf(flushError, err) + return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err) } ruleStruct := &Rule{ - nftRule: nftRule, + nftRule: nftRule, + // best effort mangle rule mangleRule: m.createPreroutingRule(expressions, userData), nftSet: ipset, ruleID: ruleId, @@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt }, ) - return m.rConn.AddRule(&nftables.Rule{ + nfRule := m.rConn.AddRule(&nftables.Rule{ Table: m.workTable, Chain: m.chainPrerouting, Exprs: preroutingExprs, UserData: userData, }) + + if err := m.rConn.Flush(); err != nil { + log.Errorf("failed to flush mangle rule %s: %v", string(userData), err) + return nil + } + + return nfRule } func (m *AclManager) createDefaultChains() (err error) { @@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro return nil } -func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { - rulesetID := ":" +func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string { + rulesetID := ":" + string(proto) + ":" if sPort != nil { rulesetID += sPort.String() } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 560f224f5..bd19f1067 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -1,11 +1,11 @@ package nftables import ( - "bytes" "context" "fmt" "net" "net/netip" + "os" "sync" "github.com/google/nftables" @@ -19,13 +19,22 @@ import ( ) const ( - // tableNameNetbird is the name of the table that is used for filtering by the Netbird client + // tableNameNetbird is the default name of the table that is used for filtering by the Netbird client tableNameNetbird = 
"netbird" + // envTableName is the environment variable to override the table name + envTableName = "NB_NFTABLES_TABLE" tableNameFilter = "filter" chainNameInput = "INPUT" ) +func getTableName() string { + if name := os.Getenv(envTableName); name != "" { + return name + } + return tableNameNetbird +} + // iFaceMapper defines subset methods of interface required for manager type iFaceMapper interface { Name() string @@ -44,16 +53,16 @@ type Manager struct { } // Create nftables firewall manager -func Create(wgIface iFaceMapper) (*Manager, error) { +func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) { m := &Manager{ rConn: &nftables.Conn{}, wgIface: wgIface, } - workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4} + workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4} var err error - m.router, err = newRouter(workTable, wgIface) + m.router, err = newRouter(workTable, wgIface, mtu) if err != nil { return nil, fmt.Errorf("create router: %w", err) } @@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { NameStr: m.wgIface.Name(), WGAddress: m.wgIface.Address(), UserspaceBind: m.wgIface.IsUserspaceBind(), + MTU: m.router.mtu, }, }); err != nil { log.Errorf("failed to update state: %v", err) @@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error { m.mutex.Lock() defer m.mutex.Unlock() - err := m.aclManager.createDefaultAllowRules() - if err != nil { - return fmt.Errorf("failed to create default allow rules: %v", err) + if err := m.aclManager.createDefaultAllowRules(); err != nil { + return fmt.Errorf("create default allow rules: %w", err) } - - chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - return fmt.Errorf("list of chains: %w", err) - } - - var chain *nftables.Chain - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - chain = c - break - } - } - - if chain == nil { - 
log.Debugf("chain INPUT not found. Skipping add allow netbird rule") - return nil - } - - rules, err := m.rConn.GetRules(chain.Table, chain) - if err != nil { - return fmt.Errorf("failed to get rules for the INPUT chain: %v", err) - } - - if rule := m.detectAllowNetbirdRule(rules); rule != nil { - log.Debugf("allow netbird rule already exists: %v", rule) - return nil - } - - m.applyAllowNetbirdRules(chain) - - err = m.rConn.Flush() - if err != nil { - return fmt.Errorf("failed to flush allow input netbird rules: %v", err) + if err := m.rConn.Flush(); err != nil { + return fmt.Errorf("flush allow input netbird rules: %w", err) } return nil @@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { m.mutex.Lock() defer m.mutex.Unlock() - if err := m.resetNetbirdInputRules(); err != nil { - return fmt.Errorf("reset netbird input rules: %v", err) - } - if err := m.router.Reset(); err != nil { return fmt.Errorf("reset router: %v", err) } @@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { return nil } -func (m *Manager) resetNetbirdInputRules() error { - chains, err := m.rConn.ListChains() - if err != nil { - return fmt.Errorf("list chains: %w", err) - } - - m.deleteNetbirdInputRules(chains) - - return nil -} - -func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) { - for _, c := range chains { - if c.Table.Name == tableNameFilter && c.Name == chainNameInput { - rules, err := m.rConn.GetRules(c.Table, c) - if err != nil { - log.Errorf("get rules for chain %q: %v", c.Name, err) - continue - } - - m.deleteMatchingRules(rules) - } - } -} - -func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) { - for _, r := range rules { - if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) { - if err := m.rConn.DelRule(r); err != nil { - log.Errorf("delete rule: %v", err) - } - } - } -} - func (m *Manager) cleanupNetbirdTables() error { tables, err := m.rConn.ListTables() if err != nil { 
return fmt.Errorf("list tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } @@ -376,61 +315,40 @@ func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return m.router.UpdateSet(set, prefixes) } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.AddInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + return m.router.RemoveInboundDNAT(localAddr, protocol, sourcePort, targetPort) +} + func (m *Manager) createWorkTable() (*nftables.Table, error) { tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4) if err != nil { return nil, fmt.Errorf("list of tables: %w", err) } + tableName := getTableName() for _, t := range tables { - if t.Name == tableNameNetbird { + if t.Name == tableName { m.rConn.DelTable(t) } } - table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}) + table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}) err = m.rConn.Flush() return table, err } -func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) { - rule := &nftables.Rule{ - Table: chain.Table, - Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: ifname(m.wgIface.Name()), - }, - &expr.Verdict{ - Kind: expr.VerdictAccept, - }, - }, - UserData: []byte(allowNetbirdInputRuleID), - } - _ = 
m.rConn.InsertRule(rule) -} - -func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule { - ifName := ifname(m.wgIface.Name()) - for _, rule := range existedRules { - if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput { - if len(rule.Exprs) < 4 { - if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME { - continue - } - if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) { - continue - } - return rule - } - } - } - return nil -} - func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { rule := &nftables.Rule{ Table: table, diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index c7f05dcb7..adec802c8 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -16,6 +16,7 @@ import ( "golang.org/x/sys/unix" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false } func TestNftablesManager(t *testing.T) { // just check on the local interface - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) { func TestNftablesManagerRuleOrder(t *testing.T) { // This test verifies rule insertion order in nftables peer ACLs // We add accept rule first, then deny rule to test ordering behavior - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) @@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) { for _, testMax := range []int{10, 20, 30, 
40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} { t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) { // just check on the local interface - manager, err := Create(mock) + manager, err := Create(mock, iface.DefaultMTU) require.NoError(t, err) require.NoError(t, manager.Init(nil)) time.Sleep(time.Second * 3) @@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr := runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) require.NoError(t, err, "failed to create manager") require.NoError(t, manager.Init(nil)) diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index e918d0524..6192c92aa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -16,6 +16,7 @@ import ( "github.com/google/nftables/xt" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -32,12 +33,17 @@ const ( chainNameRoutingNat = "netbird-rt-postrouting" chainNameRoutingRdr = "netbird-rt-redirect" chainNameForward = "FORWARD" + chainNameMangleForward = "netbird-mangle-forward" userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" + userDataAcceptInputRule = "inputaccept" dnatSuffix = "_dnat" snatSuffix = "_snat" + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 ) const refreshRulesMapError = "refresh rules map: %w" @@ -63,9 +69,10 @@ type router struct { wgIface iFaceMapper ipFwdState *ipfwdstate.IPForwardingState legacyManagement bool + mtu uint16 } -func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) { +func newRouter(workTable 
*nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) { r := &router{ conn: &nftables.Conn{}, workTable: workTable, @@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) rules: make(map[string]*nftables.Rule), wgIface: wgIface, ipFwdState: ipfwdstate.NewIPForwardingState(), + mtu: mtu, } r.ipsetCounter = refcounter.New( @@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) func (r *router) init(workTable *nftables.Table) error { r.workTable = workTable - if err := r.removeAcceptForwardRules(); err != nil { - log.Errorf("failed to clean up rules from FORWARD chain: %s", err) + if err := r.removeAcceptFilterRules(); err != nil { + log.Errorf("failed to clean up rules from filter table: %s", err) } if err := r.createContainers(); err != nil { @@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error { return nil } -// Reset cleans existing nftables default forward rules from the system +// Reset cleans existing nftables filter table rules from the system func (r *router) Reset() error { // clear without deleting the ipsets, the nf table will be deleted by the caller r.ipsetCounter.Clear() var merr *multierror.Error - if err := r.removeAcceptForwardRules(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err)) + if err := r.removeAcceptFilterRules(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } if err := r.removeNatPreroutingRules(); err != nil { @@ -220,11 +228,23 @@ func (r *router) createContainers() error { Type: nftables.ChainTypeFilter, }) + r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{ + Name: chainNameMangleForward, + Table: r.workTable, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityMangle, + Type: nftables.ChainTypeFilter, + }) + // Add the single NAT rule that matches on mark if err := 
r.addPostroutingRules(); err != nil { return fmt.Errorf("add single nat rule: %v", err) } + if err := r.addMSSClampingRules(); err != nil { + log.Errorf("failed to add MSS clamping rules: %s", err) + } + if err := r.acceptForwardRules(); err != nil { log.Errorf("failed to add accept rules for the forward chain: %s", err) } @@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error { return nil } +// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic. +// TODO: Add IPv6 support +func (r *router) addMSSClampingRules() error { + mss := r.mtu - ipTCPHeaderMinSize + + exprsOut := []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyOIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{ + Key: expr.MetaKeyL4PROTO, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte{unix.IPPROTO_TCP}, + }, + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseTransportHeader, + Offset: 13, + Len: 1, + }, + &expr.Bitwise{ + DestRegister: 1, + SourceRegister: 1, + Len: 1, + Mask: []byte{0x02}, + Xor: []byte{0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0x00}, + }, + &expr.Counter{}, + &expr.Exthdr{ + DestRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + &expr.Cmp{ + Op: expr.CmpOpGt, + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Immediate{ + Register: 1, + Data: binaryutil.BigEndian.PutUint16(uint16(mss)), + }, + &expr.Exthdr{ + SourceRegister: 1, + Type: 2, + Offset: 2, + Len: 2, + Op: expr.ExthdrOpTcpopt, + }, + } + + r.conn.AddRule(&nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameMangleForward], + Exprs: exprsOut, + }) + + return nil +} + // addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error { sourceExp, err := 
r.applyNetwork(pair.Source, nil, true) @@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error { // that our traffic is not dropped by existing rules there. // The existing FORWARD rules/policies decide outbound traffic towards our interface. // In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule. +// This method also adds INPUT chain rules to allow traffic to the local interface. func (r *router) acceptForwardRules() error { if r.filterTable == nil { log.Debugf("table 'filter' not found for forward rules, skipping accept rules") @@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error { fw := "iptables" defer func() { - log.Debugf("Used %s to add accept forward rules", fw) + log.Debugf("Used %s to add accept forward and input rules", fw) }() // Try iptables first and fallback to nftables if iptables is not available @@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) fw = "nftables" - return r.acceptForwardRulesNftables() + return r.acceptFilterRulesNftables() } - return r.acceptForwardRulesIptables(ipt) + return r.acceptFilterRulesIptables(ipt) } -func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err)) } else { - log.Debugf("added iptables rule: %v", rule) + log.Debugf("added iptables forward rule: %v", rule) } } + inputRule := r.getAcceptInputRule() + if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil { + merr = 
multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err)) + } else { + log.Debugf("added iptables input rule: %v", inputRule) + } + return nberrors.FormatErrorOrNil(merr) } @@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string { } } -func (r *router) acceptForwardRulesNftables() error { +func (r *router) getAcceptInputRule() []string { + return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"} +} + +func (r *router) acceptFilterRulesNftables() error { intf := ifname(r.wgIface.Name()) - // Rule for incoming interface (iif) with counter iifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ @@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error { }, } - // Rule for outgoing interface (oif) with counter oifRule := &nftables.Rule{ Table: r.filterTable, Chain: &nftables.Chain{ - Name: "FORWARD", + Name: chainNameForward, Table: r.filterTable, Type: nftables.ChainTypeFilter, Hooknum: nftables.ChainHookForward, @@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error { Exprs: append(oifExprs, getEstablishedExprs(2)...), UserData: []byte(userDataAcceptForwardRuleOif), } - r.conn.InsertRule(oifRule) + inputRule := &nftables.Rule{ + Table: r.filterTable, + Chain: &nftables.Chain{ + Name: chainNameInput, + Table: r.filterTable, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: intf, + }, + &expr.Counter{}, + &expr.Verdict{Kind: expr.VerdictAccept}, + }, + UserData: []byte(userDataAcceptInputRule), + } + r.conn.InsertRule(inputRule) + return nil } -func (r *router) removeAcceptForwardRules() error { +func (r *router) removeAcceptFilterRules() error { if r.filterTable == nil { return nil } - // Try iptables first and fallback to nftables if iptables is not available ipt, err := iptables.New() if 
err != nil { log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err) - return r.removeAcceptForwardRulesNftables() + return r.removeAcceptFilterRulesNftables() } - return r.removeAcceptForwardRulesIptables(ipt) + return r.removeAcceptFilterRulesIptables(ipt) } -func (r *router) removeAcceptForwardRulesNftables() error { +func (r *router) removeAcceptFilterRulesNftables() error { chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) if err != nil { return fmt.Errorf("list chains: %v", err) } for _, chain := range chains { - if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward { + if chain.Table.Name != r.filterTable.Name { + continue + } + + if chain.Name != chainNameForward && chain.Name != chainNameInput { continue } @@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error { for _, rule := range rules { if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) || - bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) { + bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) || + bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) { if err := r.conn.DelRule(rule); err != nil { return fmt.Errorf("delete rule: %v", err) } @@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error { return nil } -func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error { +func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error { var merr *multierror.Error + for _, rule := range r.getAcceptForwardRules() { if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil { - merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err)) + merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err)) } } + inputRule := r.getAcceptInputRule() + if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil { 
+ merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err)) + } + return nberrors.FormatErrorOrNil(merr) } @@ -1350,6 +1490,103 @@ func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error { return nil } +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. +func (r *router) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if _, exists := r.rules[ruleID]; exists { + return nil + } + + protoNum, err := protoToInt(protocol) + if err != nil { + return fmt.Errorf("convert protocol to number: %w", err) + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: ifname(r.wgIface.Name()), + }, + &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 2}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 2, + Data: []byte{protoNum}, + }, + &expr.Payload{ + DestRegister: 3, + Base: expr.PayloadBaseTransportHeader, + Offset: 2, + Len: 2, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 3, + Data: binaryutil.BigEndian.PutUint16(sourcePort), + }, + } + + exprs = append(exprs, applyPrefix(netip.PrefixFrom(localAddr, 32), false)...) 
+ + exprs = append(exprs, + &expr.Immediate{ + Register: 1, + Data: localAddr.AsSlice(), + }, + &expr.Immediate{ + Register: 2, + Data: binaryutil.BigEndian.PutUint16(targetPort), + }, + &expr.NAT{ + Type: expr.NATTypeDestNAT, + Family: uint32(nftables.TableFamilyIPv4), + RegAddrMin: 1, + RegProtoMin: 2, + RegProtoMax: 0, + }, + ) + + dnatRule := &nftables.Rule{ + Table: r.workTable, + Chain: r.chains[chainNameRoutingRdr], + Exprs: exprs, + UserData: []byte(ruleID), + } + r.conn.AddRule(dnatRule) + + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("add inbound DNAT rule: %w", err) + } + + r.rules[ruleID] = dnatRule + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + if err := r.refreshRulesMap(); err != nil { + return fmt.Errorf(refreshRulesMapError, err) + } + + ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort) + + if rule, exists := r.rules[ruleID]; exists { + if err := r.conn.DelRule(rule); err != nil { + return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err) + } + if err := r.conn.Flush(); err != nil { + return fmt.Errorf("flush delete inbound DNAT rule: %w", err) + } + delete(r.rules, ruleID) + } + + return nil +} + // applyNetwork generates nftables expressions for networks (CIDR) or sets func (r *router) applyNetwork( network firewall.Network, diff --git a/client/firewall/nftables/router_linux_test.go b/client/firewall/nftables/router_linux_test.go index 4fdbf3505..3531b014b 100644 --- a/client/firewall/nftables/router_linux_test.go +++ b/client/firewall/nftables/router_linux_test.go @@ -17,6 +17,7 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/test" + "github.com/netbirdio/netbird/client/iface" ) const ( @@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t 
*testing.T) { for _, testCase := range test.InsertRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { // need fw manager to init both acl mgr and router for all chains to be present - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) { for _, testCase := range test.RemoveRuleTestCases { t.Run(testCase.Name, func(t *testing.T) { - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, iface.DefaultMTU) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) }) @@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) @@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) { defer deleteWorkTable() - r, err := newRouter(workTable, ifaceMock) + r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU) require.NoError(t, err, "Failed to create router") require.NoError(t, r.init(workTable)) diff --git a/client/firewall/nftables/state_linux.go b/client/firewall/nftables/state_linux.go index f805623d6..48b7b3741 100644 --- a/client/firewall/nftables/state_linux.go +++ b/client/firewall/nftables/state_linux.go @@ -3,6 +3,7 @@ package nftables import ( "fmt" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -10,6 +11,7 @@ type InterfaceState struct { NameStr string `json:"name"` WGAddress wgaddr.Address `json:"wg_address"` UserspaceBind bool `json:"userspace_bind"` + MTU uint16 `json:"mtu"` } func (i *InterfaceState) Name() string { @@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - nft, err := Create(s.InterfaceState) + mtu := 
s.InterfaceState.MTU + if mtu == 0 { + mtu = iface.DefaultMTU + } + nft, err := Create(s.InterfaceState, mtu) if err != nil { return fmt.Errorf("create nftables manager: %w", err) } diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index bcf6d894b..7be0dd78f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -22,6 +22,8 @@ type BaseConnTrack struct { PacketsRx atomic.Uint64 BytesTx atomic.Uint64 BytesRx atomic.Uint64 + + DNATOrigPort atomic.Uint32 } // these small methods will be inlined by the compiler diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a2355e5c7..8d64412e0 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -157,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -171,28 +171,30 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui if exists { t.updateState(key, conn, flags, direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } -// TrackOutbound records an outbound TCP connection -func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, 
srcPort, dstPort, flags, nftypes.Egress, nil, size) +// TrackOutbound records an outbound TCP connection and returns the original port if DNAT reversal is needed +func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) uint16 { + if _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound processes an inbound TCP packet and updates connection state -func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size) +func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size, dnatOrigPort) } // track is the common implementation for tracking both inbound and outbound connections -func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) +func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size) if exists || flags&TCPSyn == 0 { return } @@ -210,8 +212,13 @@ func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, fla conn.tombstone.Store(false) conn.state.Store(int32(TCPStateNew)) + conn.DNATOrigPort.Store(uint32(origPort)) - t.logger.Trace2("New %s TCP connection: %s", direction, 
key) + if origPort != 0 { + t.logger.Trace4("New %s TCP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s TCP connection: %s", direction, key) + } t.updateState(key, conn, flags, direction, size) t.mutex.Lock() @@ -449,6 +456,21 @@ func (t *TCPTracker) cleanup() { } } +// GetConnection safely retrieves a connection state +func (t *TCPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*TCPConnTrack, bool) { + t.mutex.RLock() + defer t.mutex.RUnlock() + + key := ConnKey{ + SrcIP: srcIP, + DstIP: dstIP, + SrcPort: srcPort, + DstPort: dstPort, + } + conn, exists := t.connections[key] + return conn, exists +} + // Close stops the cleanup routine and releases resources func (t *TCPTracker) Close() { t.tickerCancel() diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index d01a8db4f..bb440f70a 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -603,7 +603,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { serverPort := uint16(80) // 1. Client sends SYN (we receive it as inbound) - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0) key := ConnKey{ SrcIP: clientIP, @@ -623,12 +623,12 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100) // 3. Client sends ACK to complete handshake - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion") // 4. 
Test data transfer // Client sends data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000, 0) // Server sends ACK for data tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100) @@ -637,7 +637,7 @@ func TestTCPInboundInitiatedConnection(t *testing.T) { tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500) // Client sends ACK for data - tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100) + tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0) // Verify state and counters require.Equal(t, TCPStateEstablished, conn.GetState()) diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index e7f49c46f..a3b6a418b 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -58,20 +58,23 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp return tracker } -// TrackOutbound records an outbound UDP connection -func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) { - if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists { - // if (inverted direction) conn is not tracked, track this direction - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size) +// TrackOutbound records an outbound UDP connection and returns the original port if DNAT reversal is needed +func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) uint16 { + _, origPort, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size) + if exists { + return origPort } + // if (inverted direction) conn is not tracked, track this direction + t.track(srcIP, 
dstIP, srcPort, dstPort, nftypes.Egress, nil, size, 0) + return 0 } // TrackInbound records an inbound UDP connection -func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) { - t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size) +func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int, dnatOrigPort uint16) { + t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size, dnatOrigPort) } -func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) { +func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, uint16, bool) { key := ConnKey{ SrcIP: srcIP, DstIP: dstIP, @@ -86,15 +89,15 @@ func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort if exists { conn.UpdateLastSeen() conn.UpdateCounters(direction, size) - return key, true + return key, uint16(conn.DNATOrigPort.Load()), true } - return key, false + return key, 0, false } // track is the common implementation for tracking both inbound and outbound connections -func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) { - key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) +func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int, origPort uint16) { + key, _, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size) if exists { return } @@ -109,6 +112,7 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d SourcePort: srcPort, DestPort: dstPort, } + 
conn.DNATOrigPort.Store(uint32(origPort)) conn.UpdateLastSeen() conn.UpdateCounters(direction, size) @@ -116,7 +120,11 @@ func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d t.connections[key] = conn t.mutex.Unlock() - t.logger.Trace2("New %s UDP connection: %s", direction, key) + if origPort != 0 { + t.logger.Trace4("New %s UDP connection: %s (port DNAT %d -> %d)", direction, key, origPort, dstPort) + } else { + t.logger.Trace2("New %s UDP connection: %s", direction, key) + } t.sendEvent(nftypes.TypeStart, conn, ruleID) } diff --git a/client/firewall/uspfilter/filter.go b/client/firewall/uspfilter/filter.go index 7eef49e31..990630ee4 100644 --- a/client/firewall/uspfilter/filter.go +++ b/client/firewall/uspfilter/filter.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "errors" "fmt" "net" @@ -27,7 +28,12 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" ) -const layerTypeAll = 0 +const ( + layerTypeAll = 0 + + // ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation + ipTCPHeaderMinSize = 40 +) const ( // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. @@ -36,6 +42,9 @@ const ( // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" + // EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic. + EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING" + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. 
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" @@ -50,6 +59,12 @@ const ( var errNatNotSupported = errors.New("nat not supported with userspace firewall") +// serviceKey represents a protocol/port combination for netstack service registry +type serviceKey struct { + protocol gopacket.LayerType + port uint16 +} + // RuleSet is a set of rules grouped by a string key type RuleSet map[string]PeerRule @@ -109,6 +124,17 @@ type Manager struct { dnatMappings map[netip.Addr]netip.Addr dnatMutex sync.RWMutex dnatBiMap *biDNATMap + + portDNATEnabled atomic.Bool + portDNATRules []portDNATRule + portDNATMutex sync.RWMutex + + netstackServices map[serviceKey]struct{} + netstackServiceMutex sync.RWMutex + + mtu uint16 + mssClampValue uint16 + mssClampEnabled bool } // decoder for packages @@ -122,19 +148,21 @@ type decoder struct { icmp6 layers.ICMPv6 decoded []gopacket.LayerType parser *gopacket.DecodingLayerParser + + dnatOrigPort uint16 } // Create userspace firewall manager constructor -func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - return create(iface, nil, disableServerRoutes, flowLogger) +func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + return create(iface, nil, disableServerRoutes, flowLogger, mtu) } -func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { if nativeFirewall == nil { return nil, errors.New("native firewall is nil") } - mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger) + mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu) if err != nil { return nil, err } @@ -142,8 
+170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall. return mgr, nil } -func parseCreateEnv() (bool, bool) { - var disableConntrack, enableLocalForwarding bool +func parseCreateEnv() (bool, bool, bool) { + var disableConntrack, enableLocalForwarding, disableMSSClamping bool var err error if val := os.Getenv(EnvDisableConntrack); val != "" { disableConntrack, err = strconv.ParseBool(val) @@ -162,12 +190,18 @@ func parseCreateEnv() (bool, bool) { log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err) } } + if val := os.Getenv(EnvDisableMSSClamping); val != "" { + disableMSSClamping, err = strconv.ParseBool(val) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err) + } + } - return disableConntrack, enableLocalForwarding + return disableConntrack, enableLocalForwarding, disableMSSClamping } -func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) { - disableConntrack, enableLocalForwarding := parseCreateEnv() +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) { + disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv() m := &Manager{ decoders: sync.Pool{ @@ -196,13 +230,19 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe netstack: netstack.IsEnabled(), localForwarding: enableLocalForwarding, dnatMappings: make(map[netip.Addr]netip.Addr), + portDNATRules: []portDNATRule{}, + netstackServices: make(map[serviceKey]struct{}), + mtu: mtu, } m.routingEnabled.Store(false) + if !disableMSSClamping { + m.mssClampEnabled = true + m.mssClampValue = mtu - ipTCPHeaderMinSize + } if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { return nil, fmt.Errorf("update local IPs: %w", err) } - if disableConntrack { log.Info("conntrack is disabled") 
} else { @@ -210,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger) m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger) } - - // netstack needs the forwarder for local traffic if m.netstack && m.localForwarding { if err := m.initForwarder(); err != nil { log.Errorf("failed to initialize forwarder: %v", err) } } - if err := iface.SetFilter(m); err != nil { return nil, fmt.Errorf("set filter: %w", err) } @@ -320,7 +357,7 @@ func (m *Manager) initForwarder() error { return errors.New("forwarding not supported") } - forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack) + forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu) if err != nil { m.routingEnabled.Store(false) return fmt.Errorf("create forwarder: %w", err) @@ -626,11 +663,20 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool { return false } - if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { - return true + switch d.decoded[1] { + case layers.LayerTypeUDP: + if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) { + return true + } + case layers.LayerTypeTCP: + // Clamp MSS on all TCP SYN packets, including those from local IPs. + // SNATed routed traffic may appear as local IP but still requires clamping. + if m.mssClampEnabled { + m.clampTCPMSS(packetData, d) + } } - m.trackOutbound(d, srcIP, dstIP, size) + m.trackOutbound(d, srcIP, dstIP, packetData, size) m.translateOutboundDNAT(packetData, d) return false @@ -674,14 +720,117 @@ func getTCPFlags(tcp *layers.TCP) uint8 { return flags } -func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) { +// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation. 
+// Both sides advertise their MSS during connection establishment, so we need to clamp both. +func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool { + if !d.tcp.SYN { + return false + } + if len(d.tcp.Options) == 0 { + return false + } + + mssOptionIndex := -1 + var currentMSS uint16 + for i, opt := range d.tcp.Options { + if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 { + currentMSS = binary.BigEndian.Uint16(opt.OptionData) + if currentMSS > m.mssClampValue { + mssOptionIndex = i + break + } + } + } + + if mssOptionIndex == -1 { + return false + } + + ipHeaderSize := int(d.ip4.IHL) * 4 + if ipHeaderSize < 20 { + return false + } + + if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) { + return false + } + + m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue) + return true +} + +func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool { + tcpHeaderStart := ipHeaderSize + tcpOptionsStart := tcpHeaderStart + 20 + + optOffset := tcpOptionsStart + for j := 0; j < mssOptionIndex; j++ { + switch d.tcp.Options[j].OptionType { + case layers.TCPOptionKindEndList, layers.TCPOptionKindNop: + optOffset++ + default: + optOffset += 2 + len(d.tcp.Options[j].OptionData) + } + } + + mssValueOffset := optOffset + 2 + binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue) + + m.recalculateTCPChecksum(packetData, d, tcpHeaderStart) + return true +} + +func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) { + tcpLayer := packetData[tcpHeaderStart:] + tcpLength := len(packetData) - tcpHeaderStart + + tcpLayer[16] = 0 + tcpLayer[17] = 0 + + var pseudoSum uint32 + pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1]) + pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3]) + pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1]) + pseudoSum += 
uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3]) + pseudoSum += uint32(d.ip4.Protocol) + pseudoSum += uint32(tcpLength) + + var sum uint32 = pseudoSum + for i := 0; i < tcpLength-1; i += 2 { + sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1]) + } + if tcpLength%2 == 1 { + sum += uint32(tcpLayer[tcpLength-1]) << 8 + } + + for sum > 0xFFFF { + sum = (sum & 0xFFFF) + (sum >> 16) + } + + checksum := ^uint16(sum) + binary.BigEndian.PutUint16(tcpLayer[16:18], checksum) +} + +func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) { transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + origPort := m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size) + if origPort == 0 { + break + } + if err := m.rewriteUDPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite UDP port: %v", err) + } case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + origPort := m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size) + if origPort == 0 { + break + } + if err := m.rewriteTCPPort(packetData, d, origPort, sourcePortOffset); err != nil { + m.logger.Error1("failed to rewrite TCP port: %v", err) + } case layers.LayerTypeICMPv4: m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size) } @@ -691,13 +840,15 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt transport := d.decoded[1] switch transport { case layers.LayerTypeUDP: - m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size) + m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size, 
d.dnatOrigPort) case layers.LayerTypeTCP: flags := getTCPFlags(&d.tcp) - m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size) + m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size, d.dnatOrigPort) case layers.LayerTypeICMPv4: m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size) } + + d.dnatOrigPort = 0 } // udpHooksDrop checks if any UDP hooks should drop the packet @@ -759,10 +910,20 @@ func (m *Manager) filterInbound(packetData []byte, size int) bool { return false } + // TODO: optimize port DNAT by caching matched rules in conntrack + if translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP); translated { + // Re-decode after port DNAT translation to update port information + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + m.logger.Error1("failed to re-decode packet after port DNAT: %v", err) + return true + } + srcIP, dstIP = m.extractIPs(d) + } + if translated := m.translateInboundReverse(packetData, d); translated { // Re-decode after translation to get original addresses if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { - m.logger.Error1("Failed to re-decode packet after reverse DNAT: %v", err) + m.logger.Error1("failed to re-decode packet after reverse DNAT: %v", err) return true } srcIP, dstIP = m.extractIPs(d) @@ -807,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet return true } - // If requested we pass local traffic to internal interfaces to the forwarder. - // netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder. 
- if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) { + if m.shouldForward(d, dstIP) { return m.handleForwardedLocalTraffic(packetData) } @@ -1243,3 +1402,86 @@ func (m *Manager) DisableRouting() error { return nil } + +// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port +func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + m.netstackServices[key] = struct{}{} + m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType) + m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices)) +} + +// UnregisterNetstackService removes a service from the netstack registry +func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) { + m.netstackServiceMutex.Lock() + defer m.netstackServiceMutex.Unlock() + layerType := m.protocolToLayerType(protocol) + key := serviceKey{protocol: layerType, port: port} + delete(m.netstackServices, key) + m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port) +} + +// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use +func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType { + switch protocol { + case nftypes.TCP: + return layers.LayerTypeTCP + case nftypes.UDP: + return layers.LayerTypeUDP + case nftypes.ICMP: + return layers.LayerTypeICMPv4 + default: + return gopacket.LayerType(0) // Invalid/unknown + } +} + +// shouldForward determines if a packet should be forwarded to the forwarder. +// The forwarder handles routing packets to the native OS network stack. 
+// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly. +func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool { + // not enabled, never forward + if !m.localForwarding { + return false + } + + // netstack always needs to forward because it's lacking a native interface + // exception for registered netstack services, those should go to netstack listeners + if m.netstack { + return !m.hasMatchingNetstackService(d) + } + + // traffic to our other local interfaces (not NetBird IP) - always forward + if dstIP != m.wgIface.Address().IP { + return true + } + + // traffic to our NetBird IP, not netstack mode - send to netstack listeners + return false +} + +// hasMatchingNetstackService checks if there's a registered netstack service for this packet +func (m *Manager) hasMatchingNetstackService(d *decoder) bool { + if len(d.decoded) < 2 { + return false + } + + var dstPort uint16 + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + dstPort = uint16(d.udp.DstPort) + default: + return false + } + + key := serviceKey{protocol: d.decoded[1], port: dstPort} + m.netstackServiceMutex.RLock() + _, exists := m.netstackServices[key] + m.netstackServiceMutex.RUnlock() + + return exists +} diff --git a/client/firewall/uspfilter/filter_bench_test.go b/client/firewall/uspfilter/filter_bench_test.go index 0cffcc1a7..5a2d0410f 100644 --- a/client/firewall/uspfilter/filter_bench_test.go +++ b/client/firewall/uspfilter/filter_bench_test.go @@ -17,6 +17,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: 
func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, 
flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) defer b.Cleanup(func() { require.NoError(b, manager.Close(nil)) }) @@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { } } -// generateTCPPacketWithFlags creates a TCP packet with specific flags -func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte { - b.Helper() - - ipv4 := &layers.IPv4{ - TTL: 64, - Version: 4, - SrcIP: srcIP, - DstIP: dstIP, - Protocol: layers.IPProtocolTCP, - } - - tcp := &layers.TCP{ - SrcPort: layers.TCPPort(srcPort), - DstPort: layers.TCPPort(dstPort), - } - - // Set TCP flags - tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0 - tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0 - tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0 - tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0 - tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0 - - require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4)) - - buf := gopacket.NewSerializeBuffer() - opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} - require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) - return buf.Bytes() -} - func BenchmarkRouteACLs(b *testing.B) { manager := setupRoutedManager(b, "10.10.0.100/16") @@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) { } } } + +// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound. +// This shows the overhead difference between the common case (non-SYN packets, fast path) +// and the rare case (SYN packets that need clamping, expensive path). 
+func BenchmarkMSSClamping(b *testing.B) { + scenarios := []struct { + name string + description string + genPacket func(*testing.B, net.IP, net.IP) []byte + frequency string + }{ + { + name: "syn_needs_clamp", + description: "SYN packet needing MSS clamping", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + frequency: "~0.1% of traffic - EXPENSIVE", + }, + { + name: "syn_no_clamp_needed", + description: "SYN packet with already-small MSS", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200) + }, + frequency: "~0.05% of traffic", + }, + { + name: "tcp_ack", + description: "Non-SYN TCP packet (ACK, data transfer)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + frequency: "~60-70% of traffic - FAST PATH", + }, + { + name: "tcp_psh_ack", + description: "TCP data packet (PSH+ACK)", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck)) + }, + frequency: "~10-20% of traffic - FAST PATH", + }, + { + name: "udp", + description: "UDP packet", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + frequency: "~20-30% of traffic - FAST PATH", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + 
b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled +// for the common case (non-SYN TCP packets). +func BenchmarkMSSClampingOverhead(b *testing.B) { + scenarios := []struct { + name string + enabled bool + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "disabled_tcp_ack", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "enabled_tcp_ack", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "disabled_syn_needs_clamp", + enabled: false, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "enabled_syn_needs_clamp", + enabled: true, + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = sc.enabled + if sc.enabled { + manager.mssClampValue = 1240 + } + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases +func BenchmarkMSSClampingMemory(b *testing.B) { + 
scenarios := []struct { + name string + genPacket func(*testing.B, net.IP, net.IP) []byte + }{ + { + name: "tcp_ack_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck)) + }, + }, + { + name: "syn_needs_clamp", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460) + }, + }, + { + name: "udp_fast_path", + genPacket: func(b *testing.B, src, dst net.IP) []byte { + return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP) + }, + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + manager.mssClampEnabled = true + manager.mssClampValue = 1240 + + srcIP := net.ParseIP("100.64.0.2") + dstIP := net.ParseIP("8.8.8.8") + packet := sc.genPacket(b, srcIP, dstIP) + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + manager.filterOutbound(packet, len(packet)) + } + }) + } +} + +func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte { + b.Helper() + + ip := &layers.IPv4{ + Version: 4, + IHL: 5, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Seq: 1000, + Window: 65535, + } + + require.NoError(b, tcp.SetNetworkLayerForChecksum(ip)) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{}))) + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/filter_filter_test.go 
b/client/firewall/uspfilter/filter_filter_test.go index 73f3face8..eb5aa3343 100644 --- a/client/firewall/uspfilter/filter_filter_test.go +++ b/client/firewall/uspfilter/filter_filter_test.go @@ -12,6 +12,7 @@ import ( wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/mocks" "github.com/netbirdio/netbird/client/iface/wgaddr" @@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) require.NotNil(t, manager) @@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(tb, err) require.NoError(tb, manager.EnableRouting()) require.NotNil(tb, manager) @@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) { }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index bac06814d..c56a078fc 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -1,6 +1,7 @@ package uspfilter import ( + "encoding/binary" "fmt" "net" "net/netip" @@ -17,6 +18,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + nbiface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" 
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/netflow" @@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) @@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, 
err := Create(iface, false, flowLogger) + manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() @@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) time.Sleep(time.Second) @@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) manager.udpTracker.Close() // Close the existing tracker @@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock, false, flowLogger) + manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU) require.NoError(t, err) t.Cleanup(func() { require.NoError(t, manager.Close(nil)) @@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) { require.Equal(t, tc.expected, 
isAllowed, tc.desc) } } + +func TestMSSClamping(t *testing.T) { + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() wgaddr.Address { + return wgaddr.Address{ + IP: netip.MustParseAddr("100.10.0.100"), + Network: netip.MustParsePrefix("100.10.0.0/16"), + } + }, + } + + manager, err := Create(ifaceMock, false, flowLogger, 1280) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default") + expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize) + require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40") + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + srcIP := net.ParseIP("100.10.0.2") + dstIP := net.ParseIP("8.8.8.8") + + t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType)) + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40") + }) + + t.Run("SYN packet with low MSS unchanged", func(t *testing.T) { + lowMSS := uint16(1200) + packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified") + }) + + t.Run("SYN-ACK packet gets clamped", func(t *testing.T) { + highMSS := uint16(1460) + packet := generateSYNACKPacketWithMSS(t, 
srcIP, dstIP, 12345, 80, highMSS) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Len(t, d.tcp.Options, 1, "Should have MSS option") + actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData) + require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped") + }) + + t.Run("Non-SYN packet unchanged", func(t *testing.T) { + packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck)) + + manager.filterOutbound(packet, len(packet)) + + d := parsePacket(t, packet) + require.Empty(t, d.tcp.Options, "ACK packet should have no options") + }) +} + +func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + SYN: true, + ACK: true, + Window: 65535, + Options: []layers.TCPOption{ + { + OptionType: 
layers.TCPOptionKindMSS, + OptionLength: 4, + OptionData: binary.BigEndian.AppendUint16(nil, mss), + }, + }, + } + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} + +func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte { + tb.Helper() + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocolTCP, + SrcIP: srcIP, + DstIP: dstIP, + } + + tcpLayer := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + Window: 65535, + } + + if flags&uint16(conntrack.TCPSyn) != 0 { + tcpLayer.SYN = true + } + if flags&uint16(conntrack.TCPAck) != 0 { + tcpLayer.ACK = true + } + if flags&uint16(conntrack.TCPFin) != 0 { + tcpLayer.FIN = true + } + if flags&uint16(conntrack.TCPRst) != 0 { + tcpLayer.RST = true + } + if flags&uint16(conntrack.TCPPush) != 0 { + tcpLayer.PSH = true + } + + err := tcpLayer.SetNetworkLayerForChecksum(ipLayer) + require.NoError(tb, err) + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true} + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{})) + require.NoError(tb, err) + + return buf.Bytes() +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go index 42a3e0800..00cb3f1df 100644 --- a/client/firewall/uspfilter/forwarder/forwarder.go +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -45,7 +45,7 @@ type Forwarder struct { netstack bool } -func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) { +func New(iface 
common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{ @@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow HandleLocal: false, }) - mtu, err := iface.GetDevice().MTU() - if err != nil { - return nil, fmt.Errorf("get MTU: %w", err) - } nicID := tcpip.NICID(1) endpoint := &endpoint{ logger: logger, @@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow } if err := s.CreateNIC(nicID, endpoint); err != nil { - return nil, fmt.Errorf("failed to create NIC: %v", err) + return nil, fmt.Errorf("create NIC: %v", err) } protoAddr := tcpip.ProtocolAddress{ diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go index d146de5e4..55743d975 100644 --- a/client/firewall/uspfilter/forwarder/udp.go +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -49,7 +49,7 @@ type idleConn struct { conn *udpPacketConn } -func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { +func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder { ctx, cancel := context.WithCancel(context.Background()) f := &udpForwarder{ logger: logger, diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go index 5614e2ec3..139f702f2 100644 --- a/client/firewall/uspfilter/log/log.go +++ b/client/firewall/uspfilter/log/log.go @@ -50,6 +50,8 @@ type logMessage struct { arg4 any arg5 any arg6 any + arg7 any + arg8 any } // Logger is a high-performance, non-blocking logger @@ -94,7 +96,6 @@ func (l *Logger) SetLevel(level Level) { log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) } - func (l *Logger) Error(format string) { if 
l.level.Load() >= uint32(LevelError) { select { @@ -185,6 +186,15 @@ func (l *Logger) Debug2(format string, arg1, arg2 any) { } } +func (l *Logger) Debug3(format string, arg1, arg2, arg3 any) { + if l.level.Load() >= uint32(LevelDebug) { + select { + case l.msgChannel <- logMessage{level: LevelDebug, format: format, arg1: arg1, arg2: arg2, arg3: arg3}: + default: + } + } +} + func (l *Logger) Trace1(format string, arg1 any) { if l.level.Load() >= uint32(LevelTrace) { select { @@ -239,6 +249,16 @@ func (l *Logger) Trace6(format string, arg1, arg2, arg3, arg4, arg5, arg6 any) { } } +// Trace8 logs a trace message with 8 arguments (8 placeholder in format string) +func (l *Logger) Trace8(format string, arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8 any) { + if l.level.Load() >= uint32(LevelTrace) { + select { + case l.msgChannel <- logMessage{level: LevelTrace, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4, arg5: arg5, arg6: arg6, arg7: arg7, arg8: arg8}: + default: + } + } +} + func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { *buf = (*buf)[:0] *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") @@ -260,6 +280,12 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { argCount++ if msg.arg6 != nil { argCount++ + if msg.arg7 != nil { + argCount++ + if msg.arg8 != nil { + argCount++ + } + } } } } @@ -283,6 +309,10 @@ func (l *Logger) formatMessage(buf *[]byte, msg logMessage) { formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5) case 6: formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6) + case 7: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7) + case 8: + formatted = fmt.Sprintf(msg.format, msg.arg1, msg.arg2, msg.arg3, msg.arg4, msg.arg5, msg.arg6, msg.arg7, msg.arg8) } *buf = append(*buf, formatted...) 
@@ -390,4 +420,4 @@ func (l *Logger) Stop(ctx context.Context) error { case <-done: return nil } -} \ No newline at end of file +} diff --git a/client/firewall/uspfilter/nat.go b/client/firewall/uspfilter/nat.go index 27b752531..13567872e 100644 --- a/client/firewall/uspfilter/nat.go +++ b/client/firewall/uspfilter/nat.go @@ -5,7 +5,9 @@ import ( "errors" "fmt" "net/netip" + "slices" + "github.com/google/gopacket" "github.com/google/gopacket/layers" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -13,6 +15,21 @@ import ( var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT") +var ( + errInvalidIPHeaderLength = errors.New("invalid IP header length") +) + +const ( + // Port offsets in TCP/UDP headers + sourcePortOffset = 0 + destinationPortOffset = 2 + + // IP address offsets in IPv4 header + sourceIPOffset = 12 + destinationIPOffset = 16 +) + +// ipv4Checksum calculates IPv4 header checksum. func ipv4Checksum(header []byte) uint16 { if len(header) < 20 { return 0 @@ -52,6 +69,7 @@ func ipv4Checksum(header []byte) uint16 { return ^uint16(sum) } +// icmpChecksum calculates ICMP checksum. func icmpChecksum(data []byte) uint16 { var sum1, sum2, sum3, sum4 uint32 i := 0 @@ -89,11 +107,21 @@ func icmpChecksum(data []byte) uint16 { return ^uint16(sum) } +// biDNATMap maintains bidirectional DNAT mappings. type biDNATMap struct { forward map[netip.Addr]netip.Addr reverse map[netip.Addr]netip.Addr } +// portDNATRule represents a port-specific DNAT rule. +type portDNATRule struct { + protocol gopacket.LayerType + origPort uint16 + targetPort uint16 + targetIP netip.Addr +} + +// newBiDNATMap creates a new bidirectional DNAT mapping structure. func newBiDNATMap() *biDNATMap { return &biDNATMap{ forward: make(map[netip.Addr]netip.Addr), @@ -101,11 +129,13 @@ func newBiDNATMap() *biDNATMap { } } +// set adds a bidirectional DNAT mapping between original and translated addresses. 
func (b *biDNATMap) set(original, translated netip.Addr) { b.forward[original] = translated b.reverse[translated] = original } +// delete removes a bidirectional DNAT mapping for the given original address. func (b *biDNATMap) delete(original netip.Addr) { if translated, exists := b.forward[original]; exists { delete(b.forward, original) @@ -113,19 +143,25 @@ func (b *biDNATMap) delete(original netip.Addr) { } } +// getTranslated returns the translated address for a given original address. func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) { translated, exists := b.forward[original] return translated, exists } +// getOriginal returns the original address for a given translated address. func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) { original, exists := b.reverse[translated] return original, exists } +// AddInternalDNATMapping adds a 1:1 IP address mapping for internal DNAT translation. func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error { - if !originalAddr.IsValid() || !translatedAddr.IsValid() { - return fmt.Errorf("invalid IP addresses") + if !originalAddr.IsValid() { + return fmt.Errorf("invalid original IP address") + } + if !translatedAddr.IsValid() { + return fmt.Errorf("invalid translated IP address") } if m.localipmanager.IsLocalIP(translatedAddr) { @@ -135,7 +171,6 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr m.dnatMutex.Lock() defer m.dnatMutex.Unlock() - // Initialize both maps together if either is nil if m.dnatMappings == nil || m.dnatBiMap == nil { m.dnatMappings = make(map[netip.Addr]netip.Addr) m.dnatBiMap = newBiDNATMap() @@ -151,7 +186,7 @@ func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr return nil } -// RemoveInternalDNATMapping removes a 1:1 IP address mapping +// RemoveInternalDNATMapping removes a 1:1 IP address mapping. 
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { m.dnatMutex.Lock() defer m.dnatMutex.Unlock() @@ -169,7 +204,7 @@ func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error { return nil } -// getDNATTranslation returns the translated address if a mapping exists +// getDNATTranslation returns the translated address if a mapping exists. func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return addr, false @@ -181,7 +216,7 @@ func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) { return translated, exists } -// findReverseDNATMapping finds original address for return traffic +// findReverseDNATMapping finds original address for return traffic. func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) { if !m.dnatEnabled.Load() { return translatedAddr, false @@ -193,16 +228,12 @@ func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, return original, exists } -// translateOutboundDNAT applies DNAT translation to outbound packets +// translateOutboundDNAT applies DNAT translation to outbound packets. 
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) translatedIP, exists := m.getDNATTranslation(dstIP) @@ -210,8 +241,8 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil { - m.logger.Error1("Failed to rewrite packet destination: %v", err) + if err := m.rewritePacketIP(packetData, d, translatedIP, destinationIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet destination: %v", err) return false } @@ -219,16 +250,12 @@ func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool { return true } -// translateInboundReverse applies reverse DNAT to inbound return traffic +// translateInboundReverse applies reverse DNAT to inbound return traffic. 
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { if !m.dnatEnabled.Load() { return false } - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { - return false - } - srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) originalIP, exists := m.findReverseDNATMapping(srcIP) @@ -236,8 +263,8 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return false } - if err := m.rewritePacketSource(packetData, d, originalIP); err != nil { - m.logger.Error1("Failed to rewrite packet source: %v", err) + if err := m.rewritePacketIP(packetData, d, originalIP, sourceIPOffset); err != nil { + m.logger.Error1("failed to rewrite packet source: %v", err) return false } @@ -245,21 +272,21 @@ func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool { return true } -// rewritePacketDestination replaces destination IP in the packet -func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { +// rewritePacketIP replaces an IP address (source or destination) in the packet and updates checksums. 
+func (m *Manager) rewritePacketIP(packetData []byte, d *decoder, newIP netip.Addr, ipOffset int) error { + if !newIP.Is4() { return ErrIPv4Only } - var oldDst [4]byte - copy(oldDst[:], packetData[16:20]) - newDst := newIP.As4() + var oldIP [4]byte + copy(oldIP[:], packetData[ipOffset:ipOffset+4]) + newIPBytes := newIP.As4() - copy(packetData[16:20], newDst[:]) + copy(packetData[ipOffset:ipOffset+4], newIPBytes[:]) ipHeaderLen := int(d.ip4.IHL) * 4 if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") + return errInvalidIPHeaderLength } binary.BigEndian.PutUint16(packetData[10:12], 0) @@ -269,44 +296,9 @@ func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP if len(d.decoded) > 1 { switch d.decoded[1] { case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) + m.updateTCPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:]) - case layers.LayerTypeICMPv4: - m.updateICMPChecksum(packetData, ipHeaderLen) - } - } - - return nil -} - -// rewritePacketSource replaces the source IP address in the packet -func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error { - if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() { - return ErrIPv4Only - } - - var oldSrc [4]byte - copy(oldSrc[:], packetData[12:16]) - newSrc := newIP.As4() - - copy(packetData[12:16], newSrc[:]) - - ipHeaderLen := int(d.ip4.IHL) * 4 - if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { - return fmt.Errorf("invalid IP header length") - } - - binary.BigEndian.PutUint16(packetData[10:12], 0) - ipChecksum := ipv4Checksum(packetData[:ipHeaderLen]) - binary.BigEndian.PutUint16(packetData[10:12], ipChecksum) - - if len(d.decoded) > 1 { - switch d.decoded[1] { - case layers.LayerTypeTCP: - m.updateTCPChecksum(packetData, ipHeaderLen, 
oldSrc[:], newSrc[:]) - case layers.LayerTypeUDP: - m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:]) + m.updateUDPChecksum(packetData, ipHeaderLen, oldIP[:], newIPBytes[:]) case layers.LayerTypeICMPv4: m.updateICMPChecksum(packetData, ipHeaderLen) } @@ -315,6 +307,7 @@ func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip return nil } +// updateTCPChecksum updates TCP checksum after IP address change per RFC 1624. func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { tcpStart := ipHeaderLen if len(packetData) < tcpStart+18 { @@ -327,6 +320,7 @@ func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateUDPChecksum updates UDP checksum after IP address change per RFC 1624. func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) { udpStart := ipHeaderLen if len(packetData) < udpStart+8 { @@ -344,6 +338,7 @@ func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, n binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) } +// updateICMPChecksum recalculates ICMP checksum after packet modification. func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { icmpStart := ipHeaderLen if len(packetData) < icmpStart+8 { @@ -356,7 +351,7 @@ func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) { binary.BigEndian.PutUint16(icmpData[2:4], checksum) } -// incrementalUpdate performs incremental checksum update per RFC 1624 +// incrementalUpdate performs incremental checksum update per RFC 1624. 
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { sum := uint32(^oldChecksum) @@ -391,7 +386,7 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 { return ^uint16(sum) } -// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding) +// AddDNATRule adds outbound DNAT rule for forwarding external traffic to NetBird network. func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) { if m.nativeFirewall == nil { return nil, errNatNotSupported @@ -399,10 +394,184 @@ func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) return m.nativeFirewall.AddDNATRule(rule) } -// DeleteDNATRule deletes a DNAT rule (delegates to native firewall) +// DeleteDNATRule deletes outbound DNAT rule. func (m *Manager) DeleteDNATRule(rule firewall.Rule) error { if m.nativeFirewall == nil { return errNatNotSupported } return m.nativeFirewall.DeleteDNATRule(rule) } + +// addPortRedirection adds a port redirection rule. +func (m *Manager) addPortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + rule := portDNATRule{ + protocol: protocol, + origPort: sourcePort, + targetPort: targetPort, + targetIP: targetIP, + } + + m.portDNATRules = append(m.portDNATRules, rule) + m.portDNATEnabled.Store(true) + + return nil +} + +// AddInboundDNAT adds an inbound DNAT rule redirecting traffic from NetBird peers to local services. 
+func (m *Manager) AddInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.addPortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// removePortRedirection removes a port redirection rule. +func (m *Manager) removePortRedirection(targetIP netip.Addr, protocol gopacket.LayerType, sourcePort, targetPort uint16) error { + m.portDNATMutex.Lock() + defer m.portDNATMutex.Unlock() + + m.portDNATRules = slices.DeleteFunc(m.portDNATRules, func(rule portDNATRule) bool { + return rule.protocol == protocol && rule.origPort == sourcePort && rule.targetPort == targetPort && rule.targetIP.Compare(targetIP) == 0 + }) + + if len(m.portDNATRules) == 0 { + m.portDNATEnabled.Store(false) + } + + return nil +} + +// RemoveInboundDNAT removes an inbound DNAT rule. +func (m *Manager) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Protocol, sourcePort, targetPort uint16) error { + var layerType gopacket.LayerType + switch protocol { + case firewall.ProtocolTCP: + layerType = layers.LayerTypeTCP + case firewall.ProtocolUDP: + layerType = layers.LayerTypeUDP + default: + return fmt.Errorf("unsupported protocol: %s", protocol) + } + + return m.removePortRedirection(localAddr, layerType, sourcePort, targetPort) +} + +// translateInboundPortDNAT applies port-specific DNAT translation to inbound packets. 
+func (m *Manager) translateInboundPortDNAT(packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool { + if !m.portDNATEnabled.Load() { + return false + } + + switch d.decoded[1] { + case layers.LayerTypeTCP: + dstPort := uint16(d.tcp.DstPort) + return m.applyPortRule(packetData, d, srcIP, dstIP, dstPort, layers.LayerTypeTCP, m.rewriteTCPPort) + case layers.LayerTypeUDP: + dstPort := uint16(d.udp.DstPort) + return m.applyPortRule(packetData, d, netip.Addr{}, dstIP, dstPort, layers.LayerTypeUDP, m.rewriteUDPPort) + default: + return false + } +} + +type portRewriteFunc func(packetData []byte, d *decoder, newPort uint16, portOffset int) error + +func (m *Manager) applyPortRule(packetData []byte, d *decoder, srcIP, dstIP netip.Addr, port uint16, protocol gopacket.LayerType, rewriteFn portRewriteFunc) bool { + m.portDNATMutex.RLock() + defer m.portDNATMutex.RUnlock() + + for _, rule := range m.portDNATRules { + if rule.protocol != protocol || rule.targetIP.Compare(dstIP) != 0 { + continue + } + + if rule.targetPort == port && rule.targetIP.Compare(srcIP) == 0 { + return false + } + + if rule.origPort != port { + continue + } + + if err := rewriteFn(packetData, d, rule.targetPort, destinationPortOffset); err != nil { + m.logger.Error1("failed to rewrite port: %v", err) + return false + } + d.dnatOrigPort = rule.origPort + return true + } + return false +} + +// rewriteTCPPort rewrites a TCP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteTCPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + tcpStart := ipHeaderLen + if len(packetData) < tcpStart+4 { + return fmt.Errorf("packet too short for TCP header") + } + + portStart := tcpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + if len(packetData) >= tcpStart+18 { + checksumOffset := tcpStart + 16 + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + + return nil +} + +// rewriteUDPPort rewrites a UDP port (source or destination) and updates checksum. 
+func (m *Manager) rewriteUDPPort(packetData []byte, d *decoder, newPort uint16, portOffset int) error { + ipHeaderLen := int(d.ip4.IHL) * 4 + if ipHeaderLen < 20 || ipHeaderLen > len(packetData) { + return errInvalidIPHeaderLength + } + + udpStart := ipHeaderLen + if len(packetData) < udpStart+8 { + return fmt.Errorf("packet too short for UDP header") + } + + portStart := udpStart + portOffset + oldPort := binary.BigEndian.Uint16(packetData[portStart : portStart+2]) + binary.BigEndian.PutUint16(packetData[portStart:portStart+2], newPort) + + checksumOffset := udpStart + 6 + if len(packetData) >= udpStart+8 { + oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2]) + if oldChecksum != 0 { + var oldPortBytes, newPortBytes [2]byte + binary.BigEndian.PutUint16(oldPortBytes[:], oldPort) + binary.BigEndian.PutUint16(newPortBytes[:], newPort) + + newChecksum := incrementalUpdate(oldChecksum, oldPortBytes[:], newPortBytes[:]) + binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum) + } + } + + return nil +} diff --git a/client/firewall/uspfilter/nat_bench_test.go b/client/firewall/uspfilter/nat_bench_test.go index 16dba682e..d2599e577 100644 --- a/client/firewall/uspfilter/nat_bench_test.go +++ b/client/firewall/uspfilter/nat_bench_test.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) { func BenchmarkDNATConcurrency(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: 
func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) { b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) { func BenchmarkDNATMemoryAllocations(b *testing.B) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(b, err) defer func() { require.NoError(b, manager.Close(nil)) @@ -414,3 +415,127 @@ func BenchmarkChecksumOptimizations(b *testing.B) { } }) } + +// BenchmarkPortDNAT measures the performance of port DNAT operations +func BenchmarkPortDNAT(b *testing.B) { + scenarios := []struct { + name string + proto layers.IPProtocol + setupDNAT bool + useMatchPort bool + description string + }{ + { + name: "tcp_inbound_dnat_match", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: true, + description: "TCP inbound port DNAT translation (22 → 22022)", + }, + { + name: "tcp_inbound_dnat_nomatch", + proto: layers.IPProtocolTCP, + setupDNAT: true, + useMatchPort: false, + description: "TCP inbound with DNAT configured but no port match", + }, + { + name: "tcp_inbound_no_dnat", + proto: layers.IPProtocolTCP, + setupDNAT: false, + useMatchPort: false, + description: "TCP inbound without DNAT (baseline)", + }, + { + name: "udp_inbound_dnat_match", + proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: true, + description: "UDP inbound port DNAT translation (5353 → 22054)", + }, + { + name: "udp_inbound_dnat_nomatch", + 
proto: layers.IPProtocolUDP, + setupDNAT: true, + useMatchPort: false, + description: "UDP inbound with DNAT configured but no port match", + }, + { + name: "udp_inbound_no_dnat", + proto: layers.IPProtocolUDP, + setupDNAT: false, + useMatchPort: false, + description: "UDP inbound without DNAT (baseline)", + }, + } + + for _, sc := range scenarios { + b.Run(sc.name, func(b *testing.B) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(b, err) + defer func() { + require.NoError(b, manager.Close(nil)) + }() + + // Set logger to error level to reduce noise during benchmarking + manager.SetLogLevel(log.ErrorLevel) + defer func() { + // Restore to info level after benchmark + manager.SetLogLevel(log.InfoLevel) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + var origPort, targetPort, testPort uint16 + if sc.proto == layers.IPProtocolTCP { + origPort, targetPort = 22, 22022 + } else { + origPort, targetPort = 5353, 22054 + } + + if sc.useMatchPort { + testPort = origPort + } else { + testPort = 443 // Different port + } + + // Setup port DNAT mapping if needed + if sc.setupDNAT { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(sc.proto), origPort, targetPort) + require.NoError(b, err) + } + + // Pre-establish inbound connection for outbound reverse test + if sc.setupDNAT && sc.useMatchPort { + inboundPacket := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, origPort) + manager.filterInbound(inboundPacket, 0) + } + + b.ResetTimer() + b.ReportAllocs() + + // Benchmark inbound DNAT translation + b.Run("inbound", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh packet each time + packet := generateDNATTestPacket(b, clientIP, localAddr, sc.proto, 54321, testPort) + manager.filterInbound(packet, 0) + } + }) + + // Benchmark outbound reverse DNAT 
translation (only if DNAT is set up and port matches) + if sc.setupDNAT && sc.useMatchPort { + b.Run("outbound_reverse", func(b *testing.B) { + for i := 0; i < b.N; i++ { + // Create fresh return packet (from target port) + packet := generateDNATTestPacket(b, localAddr, clientIP, sc.proto, targetPort, 54321) + manager.filterOutbound(packet, 0) + } + }) + } + }) + } +} diff --git a/client/firewall/uspfilter/nat_test.go b/client/firewall/uspfilter/nat_test.go index 710abd445..400d61020 100644 --- a/client/firewall/uspfilter/nat_test.go +++ b/client/firewall/uspfilter/nat_test.go @@ -8,6 +8,8 @@ import ( "github.com/google/gopacket/layers" "github.com/stretchr/testify/require" + firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -15,7 +17,7 @@ import ( func TestDNATTranslationCorrectness(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -99,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder { func TestDNATMappingManagement(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }, false, flowLogger) + }, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) defer func() { require.NoError(t, manager.Close(nil)) @@ -143,3 +145,111 @@ func TestDNATMappingManagement(t *testing.T) { err = manager.RemoveInternalDNATMapping(originalIP) require.Error(t, err, "Should error when removing non-existent mapping") } + +func TestInboundPortDNAT(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, 
manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + testCases := []struct { + name string + protocol layers.IPProtocol + sourcePort uint16 + targetPort uint16 + }{ + {"TCP SSH", layers.IPProtocolTCP, 22, 22022}, + {"UDP DNS", layers.IPProtocolUDP, 5353, 22054}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + err := manager.AddInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + + inboundPacket := generateDNATTestPacket(t, clientIP, localAddr, tc.protocol, 54321, tc.sourcePort) + d := parsePacket(t, inboundPacket) + + translated := manager.translateInboundPortDNAT(inboundPacket, d, clientIP, localAddr) + require.True(t, translated, "Inbound packet should be translated") + + d = parsePacket(t, inboundPacket) + var dstPort uint16 + switch tc.protocol { + case layers.IPProtocolTCP: + dstPort = uint16(d.tcp.DstPort) + case layers.IPProtocolUDP: + dstPort = uint16(d.udp.DstPort) + } + + require.Equal(t, tc.targetPort, dstPort, "Destination port should be rewritten to target port") + + err = manager.RemoveInboundDNAT(localAddr, protocolToFirewall(tc.protocol), tc.sourcePort, tc.targetPort) + require.NoError(t, err) + }) + } +} + +func TestInboundPortDNATNegative(t *testing.T) { + manager, err := Create(&IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + }, false, flowLogger, iface.DefaultMTU) + require.NoError(t, err) + defer func() { + require.NoError(t, manager.Close(nil)) + }() + + localAddr := netip.MustParseAddr("100.0.2.175") + clientIP := netip.MustParseAddr("100.0.169.249") + + err = manager.AddInboundDNAT(localAddr, firewall.ProtocolTCP, 22, 22022) + require.NoError(t, err) + + testCases := []struct { + name string + protocol layers.IPProtocol + srcIP netip.Addr + dstIP netip.Addr + srcPort uint16 + dstPort uint16 + }{ + {"Wrong port", layers.IPProtocolTCP, clientIP, 
localAddr, 54321, 80}, + {"Wrong IP", layers.IPProtocolTCP, clientIP, netip.MustParseAddr("100.64.0.99"), 54321, 22}, + {"Wrong protocol", layers.IPProtocolUDP, clientIP, localAddr, 54321, 22}, + {"ICMP", layers.IPProtocolICMPv4, clientIP, localAddr, 0, 0}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + packet := generateDNATTestPacket(t, tc.srcIP, tc.dstIP, tc.protocol, tc.srcPort, tc.dstPort) + d := parsePacket(t, packet) + + translated := manager.translateInboundPortDNAT(packet, d, tc.srcIP, tc.dstIP) + require.False(t, translated, "Packet should NOT be translated for %s", tc.name) + + d = parsePacket(t, packet) + if tc.protocol == layers.IPProtocolTCP { + require.Equal(t, tc.dstPort, uint16(d.tcp.DstPort), "Port should remain unchanged") + } else if tc.protocol == layers.IPProtocolUDP { + require.Equal(t, tc.dstPort, uint16(d.udp.DstPort), "Port should remain unchanged") + } + }) + } +} + +func protocolToFirewall(proto layers.IPProtocol) firewall.Protocol { + switch proto { + case layers.IPProtocolTCP: + return firewall.ProtocolTCP + case layers.IPProtocolUDP: + return firewall.ProtocolUDP + default: + return firewall.ProtocolALL + } +} diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go index c75c0249d..c46a6581d 100644 --- a/client/firewall/uspfilter/tracer.go +++ b/client/firewall/uspfilter/tracer.go @@ -16,25 +16,33 @@ type PacketStage int const ( StageReceived PacketStage = iota + StageInboundPortDNAT + StageInbound1to1NAT StageConntrack StagePeerACL StageRouting StageRouteACL StageForwarding StageCompleted + StageOutbound1to1NAT + StageOutboundPortReverse ) const msgProcessingCompleted = "Processing completed" func (s PacketStage) String() string { return map[PacketStage]string{ - StageReceived: "Received", - StageConntrack: "Connection Tracking", - StagePeerACL: "Peer ACL", - StageRouting: "Routing", - StageRouteACL: "Route ACL", - StageForwarding: "Forwarding", - StageCompleted: 
"Completed", + StageReceived: "Received", + StageInboundPortDNAT: "Inbound Port DNAT", + StageInbound1to1NAT: "Inbound 1:1 NAT", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + StageOutbound1to1NAT: "Outbound 1:1 NAT", + StageOutboundPortReverse: "Outbound DNAT Reverse", }[s] } @@ -261,6 +269,10 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa } func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace { + if m.handleInboundDNAT(trace, packetData, d, &srcIP, &dstIP) { + return trace + } + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { return trace } @@ -400,7 +412,16 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str } func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { - // will create or update the connection state + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageCompleted, "Packet dropped - decode error", false) + return trace + } + + m.handleOutboundDNAT(trace, packetData, d) + dropped := m.filterOutbound(packetData, 0) if dropped { trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) @@ -409,3 +430,199 @@ func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTr } return trace } + +func (m *Manager) handleInboundDNAT(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP *netip.Addr) bool { + portDNATApplied := m.traceInboundPortDNAT(trace, packetData, d) + if portDNATApplied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInboundPortDNAT, "Failed to re-decode after port DNAT", false) + return true + } + *srcIP, *dstIP = 
m.extractIPs(d) + trace.DestinationPort = m.getDestPort(d) + } + + nat1to1Applied := m.traceInbound1to1NAT(trace, packetData, d) + if nat1to1Applied { + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageInbound1to1NAT, "Failed to re-decode after 1:1 NAT", false) + return true + } + *srcIP, *dstIP = m.extractIPs(d) + } + + return false +} + +func (m *Manager) traceInboundPortDNAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageInboundPortDNAT, "Port DNAT not enabled", true) + return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageInboundPortDNAT, "Not IPv4, skipping port DNAT", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageInboundPortDNAT, "No transport layer, skipping port DNAT", true) + return false + } + + protocol := d.decoded[1] + if protocol != layers.LayerTypeTCP && protocol != layers.LayerTypeUDP { + trace.AddResult(StageInboundPortDNAT, "Not TCP/UDP, skipping port DNAT", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + var originalPort uint16 + if protocol == layers.LayerTypeTCP { + originalPort = uint16(d.tcp.DstPort) + } else { + originalPort = uint16(d.udp.DstPort) + } + + translated := m.translateInboundPortDNAT(packetData, d, srcIP, dstIP) + if translated { + ipHeaderLen := int((packetData[0] & 0x0F) * 4) + translatedPort := uint16(packetData[ipHeaderLen+2])<<8 | uint16(packetData[ipHeaderLen+3]) + + protoStr := "TCP" + if protocol == layers.LayerTypeUDP { + protoStr = "UDP" + } + msg := fmt.Sprintf("%s port DNAT applied: %s:%d -> %s:%d", protoStr, dstIP, originalPort, dstIP, translatedPort) + trace.AddResult(StageInboundPortDNAT, msg, true) + return true + } + + 
trace.AddResult(StageInboundPortDNAT, "No matching port DNAT rule", true) + return false +} + +func (m *Manager) traceInbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageInbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + + translated := m.translateInboundReverse(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatBiMap.getOriginal(srcIP) + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT reverse applied: %s -> %s", srcIP, translatedIP) + trace.AddResult(StageInbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageInbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) handleOutboundDNAT(trace *PacketTrace, packetData []byte, d *decoder) { + m.traceOutbound1to1NAT(trace, packetData, d) + m.traceOutboundPortReverse(trace, packetData, d) +} + +func (m *Manager) traceOutbound1to1NAT(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.dnatEnabled.Load() { + trace.AddResult(StageOutbound1to1NAT, "1:1 NAT not enabled", true) + return false + } + + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + translated := m.translateOutboundDNAT(packetData, d) + if translated { + m.dnatMutex.RLock() + translatedIP, exists := m.dnatMappings[dstIP] + m.dnatMutex.RUnlock() + + if exists { + msg := fmt.Sprintf("1:1 NAT applied: %s -> %s", dstIP, translatedIP) + trace.AddResult(StageOutbound1to1NAT, msg, true) + return true + } + } + + trace.AddResult(StageOutbound1to1NAT, "No matching 1:1 NAT rule", true) + return false +} + +func (m *Manager) traceOutboundPortReverse(trace *PacketTrace, packetData []byte, d *decoder) bool { + if !m.portDNATEnabled.Load() { + trace.AddResult(StageOutboundPortReverse, "Port DNAT not enabled", true) + 
return false + } + + if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 { + trace.AddResult(StageOutboundPortReverse, "Not IPv4, skipping port reverse", true) + return false + } + + if len(d.decoded) < 2 { + trace.AddResult(StageOutboundPortReverse, "No transport layer, skipping port reverse", true) + return false + } + + srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]}) + dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]}) + + var origPort uint16 + transport := d.decoded[1] + switch transport { + case layers.LayerTypeTCP: + srcPort := uint16(d.tcp.SrcPort) + dstPort := uint16(d.tcp.DstPort) + conn, exists := m.tcpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("TCP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + case layers.LayerTypeUDP: + srcPort := uint16(d.udp.SrcPort) + dstPort := uint16(d.udp.DstPort) + conn, exists := m.udpTracker.GetConnection(dstIP, dstPort, srcIP, srcPort) + if exists { + origPort = uint16(conn.DNATOrigPort.Load()) + } + if origPort != 0 { + msg := fmt.Sprintf("UDP DNAT reverse (tracked connection): %s:%d -> %s:%d", srcIP, srcPort, srcIP, origPort) + trace.AddResult(StageOutboundPortReverse, msg, true) + return true + } + default: + trace.AddResult(StageOutboundPortReverse, "Not TCP/UDP, skipping port reverse", true) + return false + } + + trace.AddResult(StageOutboundPortReverse, "No tracked connection for DNAT reverse", true) + return false +} + +func (m *Manager) getDestPort(d *decoder) uint16 { + if len(d.decoded) < 2 { + return 0 + } + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.DstPort) + default: + return 0 + } +} diff 
--git a/client/firewall/uspfilter/tracer_test.go b/client/firewall/uspfilter/tracer_test.go index 46c115787..d9f9f1aa8 100644 --- a/client/firewall/uspfilter/tracer_test.go +++ b/client/firewall/uspfilter/tracer_test.go @@ -10,6 +10,7 @@ import ( fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/wgaddr" ) @@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) { }, } - m, err := Create(ifaceMock, false, flowLogger) + m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU) require.NoError(t, err) if !statefulMode { @@ -104,6 +105,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -126,6 +129,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -153,6 +158,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -179,6 +186,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -204,6 +213,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -228,6 +239,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ 
-246,6 +259,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageRouteACL, @@ -264,6 +279,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StageCompleted, @@ -287,6 +304,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageCompleted, }, @@ -301,6 +320,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageOutbound1to1NAT, + StageOutboundPortReverse, StageCompleted, }, expectedAllow: true, @@ -319,6 +340,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -340,6 +363,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -362,6 +387,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -382,6 +409,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageConntrack, StageRouting, StagePeerACL, @@ -406,6 +435,8 @@ func TestTracePacket(t *testing.T) { }, expectedStages: []PacketStage{ StageReceived, + StageInboundPortDNAT, + StageInbound1to1NAT, StageRouting, StagePeerACL, StageCompleted, diff --git a/client/grpc/dialer.go b/client/grpc/dialer.go index 6aff53b92..7763f2417 100644 --- a/client/grpc/dialer.go +++ b/client/grpc/dialer.go @@ -4,12 +4,15 @@ import ( "context" "crypto/tls" "crypto/x509" + "errors" + "fmt" 
"runtime" "time" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/keepalive" @@ -17,6 +20,9 @@ import ( "github.com/netbirdio/netbird/util/embeddedroots" ) +// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready +var ErrConnectionShutdown = errors.New("connection shutdown before ready") + // Backoff returns a backoff configuration for gRPC calls func Backoff(ctx context.Context) backoff.BackOff { b := backoff.NewExponentialBackOff() @@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff { return backoff.WithContext(b, ctx) } +// waitForConnectionReady blocks until the connection becomes ready or fails. +// Returns an error if the connection times out, is cancelled, or enters shutdown state. +func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error { + conn.Connect() + + state := conn.GetState() + for state != connectivity.Ready && state != connectivity.Shutdown { + if !conn.WaitForStateChange(ctx, state) { + return fmt.Errorf("wait state change from %s: %w", state, ctx.Err()) + } + state = conn.GetState() + } + + if state == connectivity.Shutdown { + return ErrConnectionShutdown + } + + return nil +} + // CreateConnection creates a gRPC client connection with the appropriate transport options. // The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal"). 
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) { @@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone })) } - connCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - conn, err := grpc.DialContext( - connCtx, + conn, err := grpc.NewClient( addr, transportOption, WithCustomDialer(tlsEnabled, component), - grpc.WithBlock(), grpc.WithKeepaliveParams(keepalive.ClientParameters{ Time: 30 * time.Second, Timeout: 10 * time.Second, }), ) if err != nil { - log.Printf("DialContext error: %v", err) + return nil, fmt.Errorf("new client: %w", err) + } + + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + if err := waitForConnectionReady(ctx, conn); err != nil { + _ = conn.Close() return nil, err } diff --git a/client/grpc/dialer_generic.go b/client/grpc/dialer_generic.go index 96f347c64..479575996 100644 --- a/client/grpc/dialer_generic.go +++ b/client/grpc/dialer_generic.go @@ -18,7 +18,7 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { +func WithCustomDialer(_ bool, _ string) grpc.DialOption { return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { if runtime.GOOS == "linux" { currentUser, err := user.Current() @@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption { conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr) if err != nil { - log.Errorf("Failed to dial: %s", err) return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err) } return conn, nil diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index daf4979ce..638245bf7 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" 
"github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/acl/mocks" "github.com/netbirdio/netbird/client/internal/netflow" @@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) @@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { }).AnyTimes() ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() - fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false) + fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU) require.NoError(t, err) defer func() { err = fw.Close(nil) diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index da4f16c8d..8ca760742 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow deviceCode.VerificationURIComplete = deviceCode.VerificationURI } + if d.providerConfig.LoginHint != "" { + deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint) + if deviceCode.VerificationURI != "" { + deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint) + } + } + 
return deviceCode, err } +func appendLoginHint(uri, loginHint string) string { + if uri == "" || loginHint == "" { + return uri + } + + parsedURL, err := url.Parse(uri) + if err != nil { + log.Debugf("failed to parse verification URI for login_hint: %v", err) + return uri + } + + query := parsedURL.Query() + query.Set("login_hint", loginHint) + parsedURL.RawQuery = query.Encode() + + return parsedURL.String() +} + func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", d.providerConfig.ClientID) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 4458f600c..9fbd6cf5f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string { // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } - pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) if err != nil { - // fallback to device code flow log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debug("falling back to device code flow") - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } return pkceFlow, nil } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func 
authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } + + pkceFlowInfo.ProviderConfig.LoginHint = hint + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { @@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } + deviceFlowInfo.ProviderConfig.LoginHint = hint + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 8741e8636..738d3e34f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn params = append(params, oauth2.SetAuthURLParam("max_age", "0")) } } + if p.providerConfig.LoginHint != "" { + params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint)) + } authURL := p.oAuthConfig.AuthCodeURL(state, params...) 
diff --git a/client/internal/connect.go b/client/internal/connect.go index 7c4a3e574..6993442be 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -293,15 +293,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan } <-engineCtx.Done() + c.engineMutex.Lock() - if c.engine != nil && c.engine.wgInterface != nil { - log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name()) - if err := c.engine.Stop(); err != nil { + engine := c.engine + c.engine = nil + c.engineMutex.Unlock() + + if engine != nil && engine.wgInterface != nil { + log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name()) + if err := engine.Stop(); err != nil { log.Errorf("Failed to stop engine: %v", err) } - c.engine = nil } - c.engineMutex.Unlock() c.statusRecorder.ClientTeardown() backOff.Reset() @@ -386,19 +389,12 @@ func (c *ConnectClient) Status() StatusType { } func (c *ConnectClient) Stop() error { - if c == nil { - return nil + engine := c.Engine() + if engine != nil { + if err := engine.Stop(); err != nil { + return fmt.Errorf("stop engine: %w", err) + } } - c.engineMutex.Lock() - defer c.engineMutex.Unlock() - - if c.engine == nil { - return nil - } - if err := c.engine.Stop(); err != nil { - return fmt.Errorf("stop engine: %w", err) - } - return nil } diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index 442f54e71..fbec29ce3 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided. iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided. 
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided. +resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided. +scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided. resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder. config.txt: Anonymized configuration information of the NetBird client. network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules. @@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information: The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing. For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged. 
+ +DNS Configuration +The debug bundle includes platform-specific DNS configuration files: + +resolv.conf (Unix systems): +- Contains DNS resolver configuration from /etc/resolv.conf +- Includes nameserver entries, search domains, and resolver options +- All IP addresses and domain names are anonymized following the same rules as other files + +scutil_dns.txt (macOS only): +- Contains detailed DNS configuration from scutil --dns +- Shows DNS configuration for all network interfaces +- Includes search domains, nameservers, and DNS resolver settings +- All IP addresses and domain names are anonymized ` const ( @@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() { if err := g.addFirewallRules(); err != nil { log.Errorf("failed to add firewall rules to debug bundle: %v", err) } + + if err := g.addDNSInfo(); err != nil { + log.Errorf("failed to add DNS info to debug bundle: %v", err) + } } func (g *BundleGenerator) addReadme() error { diff --git a/client/internal/debug/debug_darwin.go b/client/internal/debug/debug_darwin.go new file mode 100644 index 000000000..91e10214f --- /dev/null +++ b/client/internal/debug/debug_darwin.go @@ -0,0 +1,53 @@ +//go:build darwin && !ios + +package debug + +import ( + "bytes" + "context" + "fmt" + "os/exec" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + if err := g.addScutilDNS(); err != nil { + log.Errorf("failed to add scutil DNS output: %v", err) + } + + return nil +} + +func (g *BundleGenerator) addScutilDNS() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "scutil", "--dns") + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("execute scutil --dns: %w", err) + } + + if 
len(bytes.TrimSpace(output)) == 0 { + return fmt.Errorf("no scutil DNS output") + } + + content := string(output) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil { + return fmt.Errorf("add scutil DNS output to zip: %w", err) + } + + return nil +} diff --git a/client/internal/debug/debug_mobile.go b/client/internal/debug/debug_mobile.go index c00c65132..3c1745ff3 100644 --- a/client/internal/debug/debug_mobile.go +++ b/client/internal/debug/debug_mobile.go @@ -5,3 +5,7 @@ package debug func (g *BundleGenerator) addRoutes() error { return nil } + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_nondarwin.go b/client/internal/debug/debug_nondarwin.go new file mode 100644 index 000000000..dfc2eace5 --- /dev/null +++ b/client/internal/debug/debug_nondarwin.go @@ -0,0 +1,16 @@ +//go:build unix && !darwin && !android + +package debug + +import ( + log "github.com/sirupsen/logrus" +) + +// addDNSInfo collects and adds DNS configuration information to the archive +func (g *BundleGenerator) addDNSInfo() error { + if err := g.addResolvConf(); err != nil { + log.Errorf("failed to add resolv.conf: %v", err) + } + + return nil +} diff --git a/client/internal/debug/debug_nonunix.go b/client/internal/debug/debug_nonunix.go new file mode 100644 index 000000000..18d017050 --- /dev/null +++ b/client/internal/debug/debug_nonunix.go @@ -0,0 +1,7 @@ +//go:build !unix + +package debug + +func (g *BundleGenerator) addDNSInfo() error { + return nil +} diff --git a/client/internal/debug/debug_unix.go b/client/internal/debug/debug_unix.go new file mode 100644 index 000000000..7e8a74eb0 --- /dev/null +++ b/client/internal/debug/debug_unix.go @@ -0,0 +1,29 @@ +//go:build unix && !android + +package debug + +import ( + "fmt" + "os" + "strings" +) + +const resolvConfPath = "/etc/resolv.conf" + +func (g *BundleGenerator) 
addResolvConf() error { + data, err := os.ReadFile(resolvConfPath) + if err != nil { + return fmt.Errorf("read %s: %w", resolvConfPath, err) + } + + content := string(data) + if g.anonymize { + content = g.anonymizer.AnonymizeString(content) + } + + if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil { + return fmt.Errorf("add resolv.conf to zip: %w", err) + } + + return nil +} diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 6bd29801d..7f7d06130 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct { Scope string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index b06ba73ab..71badf0d4 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -13,6 +13,7 @@ import ( "strings" log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool { } func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - var err error - - if err := stateManager.UpdateState(&ShutdownState{}); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } - var ( searchDomains []string matchDomains []string ) - err = s.recordSystemDNSSettings(true) - if err != nil { + if err := s.recordSystemDNSSettings(true); err != nil { log.Errorf("unable to update record of System's DNS config: %s", err.Error()) } if config.RouteAll { searchDomains = append(searchDomains, "\"\"") - err = s.addLocalDNS() - if 
err != nil { - log.Infof("failed to enable split DNS") + if err := s.addLocalDNS(); err != nil { + log.Warnf("failed to add local DNS: %v", err) } + s.updateState(stateManager) } for _, dConf := range config.Domains { @@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * } matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + var err error if len(matchDomains) != 0 { err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort) } else { @@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add match domains: %w", err) } + s.updateState(stateManager) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) if len(searchDomains) != 0 { @@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * if err != nil { return fmt.Errorf("add search domains: %w", err) } + s.updateState(stateManager) if err := s.flushDNSCache(); err != nil { log.Errorf("failed to flush DNS cache: %v", err) @@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager * return nil } +func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (s *systemConfigurator) string() string { return "scutil" } @@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error { func (s *systemConfigurator) addLocalDNS() error { if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { if err := s.recordSystemDNSSettings(true); err != nil { - log.Errorf("Unable to get system DNS configuration") return fmt.Errorf("recordSystemDNSSettings(): %w", err) } } localKey := 
getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) - if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 { - err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort) - if err != nil { - return fmt.Errorf("couldn't add local network DNS conf: %w", err) - } - } else { + if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 { log.Info("Not enabling local DNS server") + return nil + } + + if err := s.addSearchDomains( + localKey, + strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, + ); err != nil { + return fmt.Errorf("add search domains: %w", err) } return nil diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go new file mode 100644 index 000000000..c4efd17b0 --- /dev/null +++ b/client/internal/dns/host_darwin_test.go @@ -0,0 +1,111 @@ +//go:build !ios + +package dns + +import ( + "context" + "net/netip" + "os/exec" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/statemanager" +) + +func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { + if testing.Short() { + t.Skip("skipping scutil integration test in short mode") + } + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + defer func() { + require.NoError(t, sm.Stop(context.Background())) + }() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + config := HostDNSConfig{ + ServerIP: netip.MustParseAddr("100.64.0.1"), + ServerPort: 53, + RouteAll: true, + Domains: []DomainConfig{ + {Domain: "example.com", MatchOnly: true}, + }, + } + + err := configurator.applyDNSConfig(config, sm) + 
require.NoError(t, err) + + require.NoError(t, sm.PersistState(context.Background())) + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + defer func() { + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + }() + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + if exists { + t.Logf("Key %s exists before cleanup", key) + } + } + + sm2 := statemanager.New(stateFile) + sm2.RegisterState(&ShutdownState{}) + err = sm2.LoadState(&ShutdownState{}) + require.NoError(t, err) + + state := sm2.GetState(&ShutdownState{}) + if state == nil { + t.Skip("State not saved, skipping cleanup test") + } + + shutdownState, ok := state.(*ShutdownState) + require.True(t, ok) + + err = shutdownState.Cleanup() + require.NoError(t, err) + + for _, key := range []string{searchKey, matchKey, localKey} { + exists, err := checkDNSKeyExists(key) + require.NoError(t, err) + assert.False(t, exists, "Key %s should NOT exist after cleanup", key) + } +} + +func checkDNSKeyExists(key string) (bool, error) { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") + output, err := cmd.CombinedOutput() + if err != nil { + if strings.Contains(string(output), "No such key") { + return false, nil + } + return false, err + } + return !strings.Contains(string(output), "No such key"), nil +} + +func removeTestDNSKey(key string) error { + cmd := exec.Command(scutilPath) + cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") + _, err := cmd.CombinedOutput() + return err +} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 74111d335..01b7edc48 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -179,13 
+179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) var searchDomains, matchDomains []string for _, dConf := range config.Domains { @@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager r.nrptEntryCount = 0 } - if err := stateManager.UpdateState(&ShutdownState{ - Guid: r.guid, - GPO: r.gpo, - NRPTEntryCount: r.nrptEntryCount, - }); err != nil { - log.Errorf("failed to update shutdown state: %s", err) - } + r.updateState(stateManager) if err := r.updateSearchDomains(searchDomains); err != nil { return fmt.Errorf("update search domains: %w", err) @@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { + if err := stateManager.UpdateState(&ShutdownState{ + Guid: r.guid, + GPO: r.gpo, + NRPTEntryCount: r.nrptEntryCount, + }); err != nil { + log.Errorf("failed to update shutdown state: %s", err) + } +} + func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil { return fmt.Errorf("adding dns setup for all failed: %w", err) diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 8cb886203..afaf0579f 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface { // DefaultServer dns server object type DefaultServer struct { - ctx context.Context - ctxCancel context.CancelFunc + ctx context.Context + ctxCancel context.CancelFunc + 
shutdownWg sync.WaitGroup // disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running. // This is different from ServiceEnable=false from management which completely disables the DNS service. disableSys bool @@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr { // Stop stops the server func (s *DefaultServer) Stop() { s.ctxCancel() + s.shutdownWg.Wait() s.mux.Lock() defer s.mux.Unlock() @@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.applyHostConfig() + s.shutdownWg.Add(1) go func() { - // persist dns state right away + defer s.shutdownWg.Done() if err := s.stateManager.PersistState(s.ctx); err != nil { log.Errorf("Failed to persist dns state: %v", err) } diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index 11575d500..451b83f92 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface, false, flowLogger) + pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/dns/unclean_shutdown_darwin.go b/client/internal/dns/unclean_shutdown_darwin.go index 9bbdd2b56..f51b5cf8d 100644 --- a/client/internal/dns/unclean_shutdown_darwin.go +++ b/client/internal/dns/unclean_shutdown_darwin.go @@ -7,6 +7,7 @@ import ( ) type ShutdownState struct { + CreatedKeys []string } func (s *ShutdownState) Name() string { @@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error { return fmt.Errorf("create host manager: %w", err) } + for _, key := range s.CreatedKeys { + manager.createdKeys[key] = struct{}{} + } + if err := manager.restoreUncleanShutdownDNS(); err != nil { return fmt.Errorf("restore unclean shutdown dns: %w", err) } diff --git 
a/client/internal/dnsfwd/cache_test.go b/client/internal/dnsfwd/cache_test.go index c23f0f31d..44ebe290b 100644 --- a/client/internal/dnsfwd/cache_test.go +++ b/client/internal/dnsfwd/cache_test.go @@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) { t.Fatalf("expected cache miss, got=%v ok=%v", got, ok) } } - diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 7a262fa4c..aef16a8cf 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -14,6 +14,7 @@ import ( "github.com/hashicorp/go-multierror" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors "github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" @@ -33,7 +34,7 @@ type firewaller interface { } type DNSForwarder struct { - listenAddress string + listenAddress netip.AddrPort ttl uint32 statusRecorder *peer.Status @@ -47,9 +48,11 @@ type DNSForwarder struct { firewall firewaller resolver resolver cache *cache + + wgIface wgIface } -func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder { +func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder { log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) return &DNSForwarder{ listenAddress: listenAddress, @@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat statusRecorder: statusRecorder, resolver: net.DefaultResolver, cache: newCache(), + wgIface: wgIface, } } func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { - log.Infof("starting DNS forwarder on address=%s", f.listenAddress) + var netstackNet *netstack.Net + if f.wgIface != nil { + netstackNet = f.wgIface.GetNet() + } + + addrDesc := f.listenAddress.String() + if netstackNet != nil { + 
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress) + } + log.Infof("starting DNS forwarder on address=%s", addrDesc) + + udpLn, err := f.createUDPListener(netstackNet) + if err != nil { + return fmt.Errorf("create UDP listener: %w", err) + } + + tcpLn, err := f.createTCPListener(netstackNet) + if err != nil { + return fmt.Errorf("create TCP listener: %w", err) + } - // UDP server mux := dns.NewServeMux() f.mux = mux mux.HandleFunc(".", f.handleDNSQueryUDP) f.dnsServer = &dns.Server{ - Addr: f.listenAddress, - Net: "udp", - Handler: mux, + PacketConn: udpLn, + Handler: mux, } - // TCP server tcpMux := dns.NewServeMux() f.tcpMux = tcpMux tcpMux.HandleFunc(".", f.handleDNSQueryTCP) f.tcpServer = &dns.Server{ - Addr: f.listenAddress, - Net: "tcp", - Handler: tcpMux, + Listener: tcpLn, + Handler: tcpMux, } f.UpdateDomains(entries) @@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error { errCh := make(chan error, 2) go func() { - log.Infof("DNS UDP listener running on %s", f.listenAddress) - errCh <- f.dnsServer.ListenAndServe() + log.Infof("DNS UDP listener running on %s", addrDesc) + errCh <- f.dnsServer.ActivateAndServe() }() go func() { - log.Infof("DNS TCP listener running on %s", f.listenAddress) - errCh <- f.tcpServer.ListenAndServe() + log.Infof("DNS TCP listener running on %s", addrDesc) + errCh <- f.tcpServer.ActivateAndServe() }() - // return the first error we get (e.g. 
bind failure or shutdown) return <-errCh } +func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) { + if netstackNet != nil { + return netstackNet.ListenUDPAddrPort(f.listenAddress) + } + + return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress)) +} + +func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) { + if netstackNet != nil { + return netstackNet.ListenTCPAddrPort(f.listenAddress) + } + + return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress)) +} + func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { f.mutex.Lock() defer f.mutex.Unlock() diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index c1c95a2c1..4d0b96a75 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil) } - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString(tt.configuredDomain) @@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { mockResolver := &MockResolver{} // Set up forwarder - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Create entries and track sets @@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := 
NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Configure a single domain @@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) d, err := domain.FromString(tt.configured) require.NoError(t, err) @@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { func TestDNSForwarder_TCPTruncation(t *testing.T) { // Test that large UDP responses are truncated with TC bit set mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, _ := domain.FromString("example.com") @@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) { // a subsequent upstream failure still returns a successful response from cache. func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Verifies that cache normalization works across casing and trailing dot variations. 
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("ExAmPlE.CoM") @@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver // Set up complex overlapping patterns @@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { mockFirewall := &MockFirewall{} mockResolver := &MockResolver{} - forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil) forwarder.resolver = mockResolver d, err := domain.FromString("example.com") @@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { func TestDNSForwarder_EmptyQuery(t *testing.T) { // Test handling of malformed query with no questions - forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{}) + forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil) query := &dns.Msg{} // Don't set any question diff --git a/client/internal/dnsfwd/manager.go b/client/internal/dnsfwd/manager.go index a3a4ba40f..58b88d9ef 100644 --- a/client/internal/dnsfwd/manager.go +++ b/client/internal/dnsfwd/manager.go @@ -4,27 +4,34 @@ import ( "context" "fmt" "net" - "sync" + "net/netip" + "os" + "strconv" "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/tun/netstack" nberrors 
"github.com/netbirdio/netbird/client/errors" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/wgaddr" + nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" ) -var ( - // ListenPort is the port that the DNS forwarder listens on. It has been used by the client peers also - listenPort uint16 = 5353 - listenPortMu sync.RWMutex +const ( + dnsTTL = 60 + envServerPort = "NB_DNS_FORWARDER_PORT" ) -const ( - dnsTTL = 60 //seconds -) +// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder. +type wgIface interface { + GetNet() *netstack.Net + Address() wgaddr.Address +} // ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list. type ForwarderEntry struct { @@ -36,28 +43,30 @@ type ForwarderEntry struct { type Manager struct { firewall firewall.Manager statusRecorder *peer.Status + wgIface wgIface + serverPort uint16 fwRules []firewall.Rule tcpRules []firewall.Rule dnsForwarder *DNSForwarder } -func ListenPort() uint16 { - listenPortMu.RLock() - defer listenPortMu.RUnlock() - return listenPort -} +func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager { + serverPort := nbdns.ForwarderServerPort + if envPort := os.Getenv(envServerPort); envPort != "" { + if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 { + serverPort = uint16(port) + log.Infof("using custom DNS forwarder port from %s: %d", envServerPort, serverPort) + } else { + log.Warnf("invalid %s value %q, using default %d", envServerPort, envPort, nbdns.ForwarderServerPort) + } + } -func SetListenPort(port uint16) { - listenPortMu.Lock() - listenPort = port - listenPortMu.Unlock() -} - -func NewManager(fw 
firewall.Manager, statusRecorder *peer.Status) *Manager { return &Manager{ firewall: fw, statusRecorder: statusRecorder, + wgIface: wgIface, + serverPort: serverPort, } } @@ -71,7 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error { return err } - m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort()), dnsTTL, m.firewall, m.statusRecorder) + localAddr := m.wgIface.Address().IP + + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS UDP DNAT rule: %v", err) + } else { + log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } + + if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + log.Warnf("failed to add DNS TCP DNAT rule: %v", err) + } else { + log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort) + } + } + + listenAddress := netip.AddrPortFrom(localAddr, m.serverPort) + m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface) + go func() { if err := m.dnsForwarder.Listen(fwdEntries); err != nil { // todo handle close error if it is exists @@ -96,6 +123,20 @@ func (m *Manager) Stop(ctx context.Context) error { } var mErr *multierror.Error + + localAddr := m.wgIface.Address().IP + if localAddr.IsValid() && m.firewall != nil { + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err)) + } + + if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil { + mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err)) + } + } 
+ + m.unregisterNetstackServices() + if err := m.dropDNSFirewall(); err != nil { mErr = multierror.Append(mErr, err) } @@ -111,7 +152,7 @@ func (m *Manager) Stop(ctx context.Context) error { func (m *Manager) allowDNSFirewall() error { dport := &firewall.Port{ IsRange: false, - Values: []uint16{ListenPort()}, + Values: []uint16{m.serverPort}, } if m.firewall == nil { @@ -120,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error { dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add udp firewall rule: %w", err) } - m.fwRules = dnsRules tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "") if err != nil { - log.Errorf("failed to add allow DNS router rules, err: %v", err) - return err + return fmt.Errorf("add tcp firewall rule: %w", err) } + + if err := m.firewall.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + m.fwRules = dnsRules m.tcpRules = tcpRules + m.registerNetstackServices() + return nil } +func (m *Manager) registerNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + RegisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.RegisterNetstackService(nftypes.TCP, m.serverPort) + registrar.RegisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + +func (m *Manager) unregisterNetstackServices() { + if netstackNet := m.wgIface.GetNet(); netstackNet != nil { + if registrar, ok := m.firewall.(interface { + UnregisterNetstackService(protocol nftypes.Protocol, port uint16) + }); ok { + registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort) + 
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort) + log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort) + } + } +} + func (m *Manager) dropDNSFirewall() error { var mErr *multierror.Error for _, rule := range m.fwRules { diff --git a/client/internal/engine.go b/client/internal/engine.go index ac559d2b4..63a5aaca2 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -150,6 +150,8 @@ type Engine struct { // syncMsgMux is used to guarantee sequential Management Service message processing syncMsgMux *sync.Mutex + // sshMux protects sshServer field access + sshMux sync.Mutex config *EngineConfig mobileDep MobileDependency @@ -205,11 +207,12 @@ type Engine struct { updateManager *updatemanager.UpdateManager // WireGuard interface monitor - wgIfaceMonitor *WGIfaceMonitor - wgIfaceMonitorWg sync.WaitGroup + wgIfaceMonitor *WGIfaceMonitor - // dns forwarder port - dnsFwdPort uint16 + // shutdownWg tracks all long-running goroutines to ensure clean shutdown + shutdownWg sync.WaitGroup + + probeStunTurn *relay.StunTurnProbe } // Peer is an instance of the Connection Peer @@ -252,7 +255,7 @@ func NewEngine( statusRecorder: statusRecorder, checks: checks, connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit), - dnsFwdPort: dnsfwd.ListenPort(), + probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL), } sm := profilemanager.NewServiceManager("") @@ -304,17 +307,12 @@ func (e *Engine) Stop() error { e.ingressGatewayMgr = nil } + e.stopDNSForwarder() + if e.routeManager != nil { e.routeManager.Stop(e.stateManager) } - if e.dnsForwardMgr != nil { - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil - } - if e.srWatcher != nil { e.srWatcher.Close() } @@ -335,10 +333,6 @@ func (e *Engine) Stop() error { e.cancel() } - // very ugly but we want to remove peers from the WireGuard interface first before 
removing interface. - // Removing peers happens in the conn.Close() asynchronously - time.Sleep(500 * time.Millisecond) - e.close() // stop flow manager after wg interface is gone @@ -346,8 +340,6 @@ func (e *Engine) Stop() error { e.flowManager.Close() } - log.Infof("stopped Netbird Engine") - ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -358,12 +350,52 @@ func (e *Engine) Stop() error { log.Errorf("failed to persist state: %v", err) } - // Stop WireGuard interface monitor and wait for it to exit - e.wgIfaceMonitorWg.Wait() + timeout := e.calculateShutdownTimeout() + log.Debugf("waiting for goroutines to finish with timeout: %v", timeout) + shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil { + log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout) + } + + log.Infof("stopped Netbird Engine") return nil } +// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s. +func (e *Engine) calculateShutdownTimeout() time.Duration { + peerCount := len(e.peerStore.PeersPubKey()) + + baseTimeout := 10 * time.Second + perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond + timeout := baseTimeout + perPeerTimeout + + maxTimeout := 30 * time.Second + if timeout > maxTimeout { + timeout = maxTimeout + } + + return timeout +} + +// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout. +func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error { + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services // Connections to remote peers are not established here. 
// However, they will be established once an event with a list of peers to connect to will be received from Management Service @@ -493,14 +525,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) // monitor WireGuard interface lifecycle and restart engine on changes e.wgIfaceMonitor = NewWGIfaceMonitor() - e.wgIfaceMonitorWg.Add(1) + e.shutdownWg.Add(1) go func() { - defer e.wgIfaceMonitorWg.Done() + defer e.shutdownWg.Done() if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart { log.Infof("WireGuard interface monitor: %s, restarting engine", err) - e.restartEngine() + e.triggerClientRestart() } else if err != nil { log.Warnf("WireGuard interface monitor: %s", err) } @@ -529,7 +561,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -697,9 +729,11 @@ func (e *Engine) removeAllPeers() error { func (e *Engine) removePeer(peerKey string) error { log.Debugf("removing peer from engine %s", peerKey) + e.sshMux.Lock() if !isNil(e.sshServer) { e.sshServer.RemoveAuthorizedKey(peerKey) } + e.sshMux.Unlock() e.connMgr.RemovePeerConn(peerKey) @@ -935,6 +969,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { log.Warnf("running SSH server on %s is not supported", runtime.GOOS) return nil } + e.sshMux.Lock() // start SSH server if it wasn't running if isNil(e.sshServer) { listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort) @@ -942,34 +977,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { listenAddr = fmt.Sprintf("127.0.0.1:%d", 
nbssh.DefaultSSHPort) } // nil sshServer means it has not yet been started - var err error - e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr) - + server, err := e.sshServerFunc(e.config.SSHKey, listenAddr) if err != nil { + e.sshMux.Unlock() return fmt.Errorf("create ssh server: %w", err) } + + e.sshServer = server + e.sshMux.Unlock() + go func() { // blocking - err = e.sshServer.Start() + err = server.Start() if err != nil { // will throw error when we stop it even if it is a graceful stop log.Debugf("stopped SSH server with error %v", err) } - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + e.sshMux.Lock() e.sshServer = nil + e.sshMux.Unlock() log.Infof("stopped SSH server") }() } else { + e.sshMux.Unlock() log.Debugf("SSH server is already running") } - } else if !isNil(e.sshServer) { - // Disable SSH server request, so stop it if it was running - err := e.sshServer.Stop() - if err != nil { - log.Warnf("failed to stop SSH server %v", err) + } else { + e.sshMux.Lock() + if !isNil(e.sshServer) { + // Disable SSH server request, so stop it if it was running + err := e.sshServer.Stop() + if err != nil { + log.Warnf("failed to stop SSH server %v", err) + } + e.sshServer = nil } - e.sshServer = nil + e.sshMux.Unlock() } return nil } @@ -1006,7 +1049,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. 
func (e *Engine) receiveManagementEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() info, err := system.GetInfoWithChecks(e.ctx, e.checks) if err != nil { log.Warnf("failed to get system info with checks: %v", err) @@ -1116,10 +1161,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { protoDNSConfig = &mgmProto.DNSConfig{} } - if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil { + dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network) + + if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil { log.Errorf("failed to update dns server, err: %v", err) } + e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort) + // apply routes first, route related actions might depend on routing being enabled routes := toRoutes(networkMap.GetRoutes()) serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes) @@ -1140,7 +1189,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes) - e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries, uint16(protoDNSConfig.ForwarderPort)) + e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries) // Ingress forward rules forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules()) @@ -1178,6 +1227,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.statusRecorder.FinishPeerListModifications() // update SSHServer by adding remote peer SSH keys + e.sshMux.Lock() if !isNil(e.sshServer) { for _, config := range networkMap.GetRemotePeers() { if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil { @@ -1188,6 +1238,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { } } } + e.sshMux.Unlock() } // must set the exclude list after the peers are added. 
Without it the manager can not figure out the peers parameters from the store @@ -1264,10 +1315,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE } func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config { + forwarderPort := uint16(protoDNSConfig.GetForwarderPort()) + if forwarderPort == 0 { + forwarderPort = nbdns.ForwarderClientPort + } + dnsUpdate := nbdns.Config{ ServiceEnable: protoDNSConfig.GetServiceEnable(), CustomZones: make([]nbdns.CustomZone, 0), NameServerGroups: make([]*nbdns.NameServerGroup, 0), + ForwarderPort: forwarderPort, } for _, zone := range protoDNSConfig.GetCustomZones() { @@ -1424,7 +1481,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() // connect to a stream of messages coming from the signal server err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error { e.syncMsgMux.Lock() @@ -1541,12 +1600,14 @@ func (e *Engine) close() { e.statusRecorder.SetWgIface(nil) } + e.sshMux.Lock() if !isNil(e.sshServer) { err := e.sshServer.Stop() if err != nil { log.Warnf("failed stopping the SSH server: %v", err) } } + e.sshMux.Unlock() if e.firewall != nil { err := e.firewall.Close(e.stateManager) @@ -1723,7 +1784,7 @@ func (e *Engine) getRosenpassAddr() string { // RunHealthProbes executes health checks for Signal, Management, Relay and WireGuard services // and updates the status recorder with the latest states. 
-func (e *Engine) RunHealthProbes() bool { +func (e *Engine) RunHealthProbes(waitForResult bool) bool { e.syncMsgMux.Lock() signalHealthy := e.signal.IsHealthy() @@ -1755,8 +1816,12 @@ func (e *Engine) RunHealthProbes() bool { } e.syncMsgMux.Unlock() - - results := e.probeICE(stuns, turns) + var results []relay.ProbeResult + if waitForResult { + results = e.probeStunTurn.ProbeAllWaitResult(e.ctx, stuns, turns) + } else { + results = e.probeStunTurn.ProbeAll(e.ctx, stuns, turns) + } e.statusRecorder.UpdateRelayStates(results) relayHealthy := true @@ -1773,15 +1838,10 @@ func (e *Engine) RunHealthProbes() bool { return allHealthy } -func (e *Engine) probeICE(stuns, turns []*stun.URI) []relay.ProbeResult { - return append( - relay.ProbeAll(e.ctx, relay.ProbeSTUN, stuns), - relay.ProbeAll(e.ctx, relay.ProbeTURN, turns)..., - ) -} - -// restartEngine restarts the engine by cancelling the client context -func (e *Engine) restartEngine() { +// triggerClientRestart triggers a full client restart by cancelling the client context. +// Note: This does NOT just restart the engine - it cancels the entire client context, +// which causes the connect client's retry loop to create a completely new engine. 
+func (e *Engine) triggerClientRestart() { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() @@ -1803,7 +1863,9 @@ func (e *Engine) startNetworkMonitor() { } e.networkMonitor = networkmonitor.New() + e.shutdownWg.Add(1) go func() { + defer e.shutdownWg.Done() if err := e.networkMonitor.Listen(e.ctx); err != nil { if errors.Is(err, context.Canceled) { log.Infof("network monitor stopped") @@ -1813,8 +1875,8 @@ func (e *Engine) startNetworkMonitor() { return } - log.Infof("Network monitor: detected network change, restarting engine") - e.restartEngine() + log.Infof("Network monitor: detected network change, triggering client restart") + e.triggerClientRestart() }() } @@ -1899,64 +1961,50 @@ func (e *Engine) GetWgAddr() netip.Addr { func (e *Engine) updateDNSForwarder( enabled bool, fwdEntries []*dnsfwd.ForwarderEntry, - forwarderPort uint16, ) { if e.config.DisableServerRoutes { return } - if forwarderPort > 0 { - dnsfwd.SetListenPort(forwarderPort) - } - if !enabled { - if e.dnsForwardMgr == nil { - return - } - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } + e.stopDNSForwarder() return } if len(fwdEntries) > 0 { - switch { - case e.dnsForwardMgr == nil: - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) - if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { - log.Errorf("failed to start DNS forward: %v", err) - e.dnsForwardMgr = nil - } - log.Infof("started domain router service with %d entries", len(fwdEntries)) - case e.dnsFwdPort != forwarderPort: - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - e.restartDnsFwd(fwdEntries, forwarderPort) - e.dnsFwdPort = forwarderPort - - default: + if e.dnsForwardMgr == nil { + e.startDNSForwarder(fwdEntries) + } else { e.dnsForwardMgr.UpdateDomains(fwdEntries) } } else if e.dnsForwardMgr != nil { log.Infof("disable domain router service") - if err := 
e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = nil + e.stopDNSForwarder() } - } -func (e *Engine) restartDnsFwd(fwdEntries []*dnsfwd.ForwarderEntry, forwarderPort uint16) { - log.Infof("updating domain router service port from %d to %d", e.dnsFwdPort, forwarderPort) - // stop and start the forwarder to apply the new port - if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { - log.Errorf("failed to stop DNS forward: %v", err) - } - e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder) +func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) { + e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface) + if err := e.dnsForwardMgr.Start(fwdEntries); err != nil { log.Errorf("failed to start DNS forward: %v", err) e.dnsForwardMgr = nil + return } + + log.Infof("started domain router service with %d entries", len(fwdEntries)) +} + +func (e *Engine) stopDNSForwarder() { + if e.dnsForwardMgr == nil { + return + } + + if err := e.dnsForwardMgr.Stop(context.Background()); err != nil { + log.Errorf("failed to stop DNS forward: %v", err) + } + + e.dnsForwardMgr = nil } func (e *Engine) GetNet() (*netstack.Net, error) { diff --git a/client/internal/netflow/logger/logger.go b/client/internal/netflow/logger/logger.go index 899faf108..a033a2a7c 100644 --- a/client/internal/netflow/logger/logger.go +++ b/client/internal/netflow/logger/logger.go @@ -10,10 +10,10 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/netflow/store" "github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/peer" + "github.com/netbirdio/netbird/dns" ) type rcvChan chan *types.EventFields @@ -138,7 +138,8 @@ func (l *Logger) UpdateConfig(dnsCollection, exitNodeCollection bool) { func (l 
*Logger) shouldStore(event *types.EventFields, isExitNode bool) bool { // check dns collection - if !l.dnsCollection.Load() && event.Protocol == types.UDP && (event.DestPort == 53 || event.DestPort == uint16(dnsfwd.ListenPort())) { + if !l.dnsCollection.Load() && event.Protocol == types.UDP && + (event.DestPort == 53 || event.DestPort == dns.ForwarderClientPort || event.DestPort == dns.ForwarderServerPort) { return false } diff --git a/client/internal/netflow/manager.go b/client/internal/netflow/manager.go index e3b188468..7752c97b0 100644 --- a/client/internal/netflow/manager.go +++ b/client/internal/netflow/manager.go @@ -24,6 +24,7 @@ import ( // Manager handles netflow tracking and logging type Manager struct { mux sync.Mutex + shutdownWg sync.WaitGroup logger nftypes.FlowLogger flowConfig *nftypes.FlowConfig conntrack nftypes.ConnTracker @@ -105,8 +106,15 @@ func (m *Manager) resetClient() error { ctx, cancel := context.WithCancel(context.Background()) m.cancel = cancel - go m.receiveACKs(ctx, flowClient) - go m.startSender(ctx) + m.shutdownWg.Add(2) + go func() { + defer m.shutdownWg.Done() + m.receiveACKs(ctx, flowClient) + }() + go func() { + defer m.shutdownWg.Done() + m.startSender(ctx) + }() return nil } @@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error { // Close cleans up all resources func (m *Manager) Close() { m.mux.Lock() - defer m.mux.Unlock() - if err := m.disableFlow(); err != nil { log.Warnf("failed to disable flow manager: %v", err) } + m.mux.Unlock() + + m.shutdownWg.Wait() } // GetLogger returns the flow logger diff --git a/client/internal/networkmonitor/check_change_bsd.go b/client/internal/networkmonitor/check_change_bsd.go index f5eb2c739..b3482f54e 100644 --- a/client/internal/networkmonitor/check_change_bsd.go +++ b/client/internal/networkmonitor/check_change_bsd.go @@ -1,4 +1,4 @@ -//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd +//go:build dragonfly || freebsd || netbsd || 
openbsd package networkmonitor @@ -6,21 +6,19 @@ import ( "context" "errors" "fmt" - "syscall" - "unsafe" log "github.com/sirupsen/logrus" - "golang.org/x/net/route" "golang.org/x/sys/unix" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" ) func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + fd, err := prepareFd() if err != nil { return fmt.Errorf("open routing socket: %v", err) } + defer func() { err := unix.Close(fd) if err != nil && !errors.Is(err, unix.EBADF) { @@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er } }() - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - buf := make([]byte, 2048) - n, err := unix.Read(fd, buf) - if err != nil { - if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { - log.Warnf("Network monitor: failed to read from routing socket: %v", err) - } - continue - } - if n < unix.SizeofRtMsghdr { - log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) - continue - } - - msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - - switch msg.Type { - // handle route changes - case unix.RTM_ADD, syscall.RTM_DELETE: - route, err := parseRouteMessage(buf[:n]) - if err != nil { - log.Debugf("Network monitor: error parsing routing message: %v", err) - continue - } - - if route.Dst.Bits() != 0 { - continue - } - - intf := "" - if route.Interface != nil { - intf = route.Interface.Name - } - switch msg.Type { - case unix.RTM_ADD: - log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) - return nil - case unix.RTM_DELETE: - if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { - log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) - return nil - } - } - 
} - } - } -} - -func parseRouteMessage(buf []byte) (*systemops.Route, error) { - msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) - if err != nil { - return nil, fmt.Errorf("parse RIB: %v", err) - } - - if len(msgs) != 1 { - return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) - } - - msg, ok := msgs[0].(*route.RouteMessage) - if !ok { - return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) - } - - return systemops.MsgToRoute(msg) + return routeCheck(ctx, fd, nexthopv4, nexthopv6) } diff --git a/client/internal/networkmonitor/check_change_common.go b/client/internal/networkmonitor/check_change_common.go new file mode 100644 index 000000000..c287236e8 --- /dev/null +++ b/client/internal/networkmonitor/check_change_common.go @@ -0,0 +1,92 @@ +//go:build dragonfly || freebsd || netbsd || openbsd || darwin + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "syscall" + "unsafe" + + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +func prepareFd() (int, error) { + return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) +} + +func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + buf := make([]byte, 2048) + n, err := unix.Read(fd, buf) + if err != nil { + if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) { + log.Warnf("Network monitor: failed to read from routing socket: %v", err) + } + continue + } + if n < unix.SizeofRtMsghdr { + log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n) + continue + } + + msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) + + switch msg.Type { + // handle route changes + case unix.RTM_ADD, syscall.RTM_DELETE: + route, err := parseRouteMessage(buf[:n]) + if err != nil { + log.Debugf("Network monitor: 
error parsing routing message: %v", err) + continue + } + + if route.Dst.Bits() != 0 { + continue + } + + intf := "" + if route.Interface != nil { + intf = route.Interface.Name + } + switch msg.Type { + case unix.RTM_ADD: + log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf) + return nil + case unix.RTM_DELETE: + if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 { + log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf) + return nil + } + } + } + } + } +} + +func parseRouteMessage(buf []byte) (*systemops.Route, error) { + msgs, err := route.ParseRIB(route.RIBTypeRoute, buf) + if err != nil { + return nil, fmt.Errorf("parse RIB: %v", err) + } + + if len(msgs) != 1 { + return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs) + } + + msg, ok := msgs[0].(*route.RouteMessage) + if !ok { + return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0]) + } + + return systemops.MsgToRoute(msg) +} diff --git a/client/internal/networkmonitor/check_change_darwin.go b/client/internal/networkmonitor/check_change_darwin.go new file mode 100644 index 000000000..ddc6e1736 --- /dev/null +++ b/client/internal/networkmonitor/check_change_darwin.go @@ -0,0 +1,149 @@ +//go:build darwin && !ios + +package networkmonitor + +import ( + "context" + "errors" + "fmt" + "hash/fnv" + "os/exec" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" + + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" +) + +// todo: refactor to not use static functions + +func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error { + fd, err := prepareFd() + if err != nil { + return fmt.Errorf("open routing socket: %v", err) + } + + defer func() { + if err := unix.Close(fd); err != nil { + if !errors.Is(err, unix.EBADF) { + log.Warnf("Network monitor: failed to close routing socket: %v", 
err) + } + } + }() + + routeChanged := make(chan struct{}) + go func() { + _ = routeCheck(ctx, fd, nexthopv4, nexthopv6) + close(routeChanged) + }() + + wakeUp := make(chan struct{}) + go func() { + wakeUpListen(ctx) + close(wakeUp) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-routeChanged: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("route change detected") + return nil + case <-wakeUp: + if ctx.Err() != nil { + return ctx.Err() + } + log.Infof("wakeup detected") + return nil + } +} + +func wakeUpListen(ctx context.Context) { + log.Infof("start to watch for system wakeups") + var ( + initialHash uint32 + err error + ) + + // Keep retrying until initial sysctl succeeds or context is canceled + for { + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + default: + initialHash, err = readSleepTimeHash() + if err != nil { + log.Errorf("failed to detect initial sleep time: %v", err) + select { + case <-ctx.Done(): + log.Info("exit from wakeUpListen initial hash detection due to context cancellation") + return + case <-time.After(3 * time.Second): + continue + } + } + log.Debugf("initial wakeup hash: %d", initialHash) + break + } + break + } + + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + log.Info("context canceled, stopping wakeUpListen") + return + + case <-ticker.C: + newHash, err := readSleepTimeHash() + if err != nil { + log.Errorf("failed to read sleep time hash: %v", err) + continue + } + + if newHash == initialHash { + log.Tracef("no wakeup detected") + continue + } + + upOut, err := exec.Command("uptime").Output() + if err != nil { + log.Errorf("failed to run uptime command: %v", err) + upOut = []byte("unknown") + } + log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut) + return + } + } +} + +func readSleepTimeHash() (uint32, error) { + cmd := 
exec.Command("sysctl", "kern.sleeptime") + out, err := cmd.Output() + if err != nil { + return 0, fmt.Errorf("failed to run sysctl: %w", err) + } + + h, err := hash(out) + if err != nil { + return 0, fmt.Errorf("failed to compute hash: %w", err) + } + + return h, nil +} + +func hash(data []byte) (uint32, error) { + hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher + if _, err := hasher.Write(data); err != nil { + return 0, err + } + return hasher.Sum32(), nil +} diff --git a/client/internal/networkmonitor/monitor.go b/client/internal/networkmonitor/monitor.go index accdd9c9d..6d019258d 100644 --- a/client/internal/networkmonitor/monitor.go +++ b/client/internal/networkmonitor/monitor.go @@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { event := make(chan struct{}, 1) go nw.checkChanges(ctx, event, nexthop4, nexthop6) + log.Infof("start watching for network changes") // debounce changes timer := time.NewTimer(0) timer.Stop() diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 686430752..6f4f5ad4f 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -19,11 +19,10 @@ type SRWatcher struct { signalClient chNotifier relayManager chNotifier - listeners map[chan struct{}]struct{} - mu sync.Mutex - iFaceDiscover stdnet.ExternalIFaceDiscover - iceConfig ice.Config - + listeners map[chan struct{}]struct{} + mu sync.Mutex + iFaceDiscover stdnet.ExternalIFaceDiscover + iceConfig ice.Config cancelIceMonitor context.CancelFunc } diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index 3675f0157..5d8ebfe45 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer 
*OfferAnswer) (*ice.Conn, error) { if isController(w.config) { - return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) + return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } else { return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd) } diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a713bb342..23c92e8af 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct { DisablePromptLogin bool // LoginFlag is used to configure the PKCE flow login behavior LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it diff --git a/client/internal/relay/relay.go b/client/internal/relay/relay.go index fa208716f..693ea1f31 100644 --- a/client/internal/relay/relay.go +++ b/client/internal/relay/relay.go @@ -2,6 +2,8 @@ package relay import ( "context" + "crypto/sha256" + "errors" "fmt" "net" "sync" @@ -15,6 +17,15 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) +const ( + DefaultCacheTTL = 20 * time.Second + probeTimeout = 6 * time.Second +) + +var ( + ErrCheckInProgress = errors.New("probe check is already in progress") +) + // ProbeResult holds the info about the result of a relay probe request type ProbeResult struct { URI string @@ -22,8 +33,164 @@ type ProbeResult struct { Addr string } +type StunTurnProbe struct { + cacheResults []ProbeResult + cacheTimestamp time.Time + cacheKey string + cacheTTL time.Duration + probeInProgress bool + probeDone chan struct{} + mu sync.Mutex +} + +func NewStunTurnProbe(cacheTTL time.Duration) *StunTurnProbe { + return &StunTurnProbe{ + cacheTTL: cacheTTL, + } +} + +func (p *StunTurnProbe) ProbeAllWaitResult(ctx 
context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + if p.probeInProgress { + doneChan := p.probeDone + p.mu.Unlock() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-doneChan: + return p.getCachedResults(cacheKey, stuns, turns) + } + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + p.mu.Unlock() + + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + + return p.getCachedResults(cacheKey, stuns, turns) +} + +// ProbeAll probes all given servers asynchronously and returns the results +func (p *StunTurnProbe) ProbeAll(ctx context.Context, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + cacheKey := generateCacheKey(stuns, turns) + + p.mu.Lock() + + if results := p.checkCache(cacheKey); results != nil { + p.mu.Unlock() + return results + } + + if p.probeInProgress { + p.mu.Unlock() + return createErrorResults(stuns, turns) + } + + p.probeInProgress = true + probeDone := make(chan struct{}) + p.probeDone = probeDone + log.Infof("started new probe for STUN, TURN servers") + go func() { + p.doProbe(ctx, stuns, turns, cacheKey) + close(probeDone) + }() + + p.mu.Unlock() + + timer := time.NewTimer(1300 * time.Millisecond) + defer timer.Stop() + + select { + case <-ctx.Done(): + log.Debugf("Context cancelled while waiting for probe results") + return createErrorResults(stuns, turns) + case <-probeDone: + // when the probe is return fast, return the results right away + return p.getCachedResults(cacheKey, stuns, turns) + case <-timer.C: + // if the probe takes longer than 1.3s, return error results to avoid blocking + return createErrorResults(stuns, turns) + } +} + +func (p *StunTurnProbe) checkCache(cacheKey string) []ProbeResult { + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + age := time.Since(p.cacheTimestamp) + if age 
< p.cacheTTL { + results := append([]ProbeResult(nil), p.cacheResults...) + log.Debugf("returning cached probe results (age: %v)", age) + return results + } + } + return nil +} + +func (p *StunTurnProbe) getCachedResults(cacheKey string, stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + p.mu.Lock() + defer p.mu.Unlock() + + if p.cacheKey == cacheKey && len(p.cacheResults) > 0 { + return append([]ProbeResult(nil), p.cacheResults...) + } + return createErrorResults(stuns, turns) +} + +func (p *StunTurnProbe) doProbe(ctx context.Context, stuns []*stun.URI, turns []*stun.URI, cacheKey string) { + defer func() { + p.mu.Lock() + p.probeInProgress = false + p.mu.Unlock() + }() + results := make([]ProbeResult, len(stuns)+len(turns)) + + var wg sync.WaitGroup + for i, uri := range stuns { + wg.Add(1) + go func(idx int, stunURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = stunURI.String() + results[idx].Addr, results[idx].Err = p.probeSTUN(probeCtx, stunURI) + }(i, uri) + } + + stunOffset := len(stuns) + for i, uri := range turns { + wg.Add(1) + go func(idx int, turnURI *stun.URI) { + defer wg.Done() + + probeCtx, cancel := context.WithTimeout(ctx, probeTimeout) + defer cancel() + + results[idx].URI = turnURI.String() + results[idx].Addr, results[idx].Err = p.probeTURN(probeCtx, turnURI) + }(stunOffset+i, uri) + } + + wg.Wait() + + p.mu.Lock() + p.cacheResults = results + p.cacheTimestamp = time.Now() + p.cacheKey = cacheKey + p.mu.Unlock() + + log.Debug("Stored new probe results in cache") +} + // ProbeSTUN tries binding to the given STUN uri and acquiring an address -func ProbeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeSTUN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("stun probe error from %s: %s", uri, probeErr) @@ -83,7 +250,7 @@ func ProbeSTUN(ctx 
context.Context, uri *stun.URI) (addr string, probeErr error) } // ProbeTURN tries allocating a session from the given TURN URI -func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { +func (p *StunTurnProbe) probeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) { defer func() { if probeErr != nil { log.Debugf("turn probe error from %s: %s", uri, probeErr) @@ -160,28 +327,28 @@ func ProbeTURN(ctx context.Context, uri *stun.URI) (addr string, probeErr error) return relayConn.LocalAddr().String(), nil } -// ProbeAll probes all given servers asynchronously and returns the results -func ProbeAll( - ctx context.Context, - fn func(ctx context.Context, uri *stun.URI) (addr string, probeErr error), - relays []*stun.URI, -) []ProbeResult { - results := make([]ProbeResult, len(relays)) +func createErrorResults(stuns []*stun.URI, turns []*stun.URI) []ProbeResult { + total := len(stuns) + len(turns) + results := make([]ProbeResult, total) - var wg sync.WaitGroup - for i, uri := range relays { - ctx, cancel := context.WithTimeout(ctx, 6*time.Second) - defer cancel() - - wg.Add(1) - go func(res *ProbeResult, stunURI *stun.URI) { - defer wg.Done() - res.URI = stunURI.String() - res.Addr, res.Err = fn(ctx, stunURI) - }(&results[i], uri) + allURIs := append(append([]*stun.URI{}, stuns...), turns...) 
+ for i, uri := range allURIs { + results[i] = ProbeResult{ + URI: uri.String(), + Err: ErrCheckInProgress, + } } - wg.Wait() - return results } + +func generateCacheKey(stuns []*stun.URI, turns []*stun.URI) string { + h := sha256.New() + for _, uri := range stuns { + h.Write([]byte(uri.String())) + } + for _, uri := range turns { + h.Write([]byte(uri.String())) + } + return fmt.Sprintf("%x", h.Sum(nil)) +} diff --git a/client/internal/routemanager/common/params.go b/client/internal/routemanager/common/params.go index def18411f..8b5407850 100644 --- a/client/internal/routemanager/common/params.go +++ b/client/internal/routemanager/common/params.go @@ -1,6 +1,7 @@ package common import ( + "sync/atomic" "time" "github.com/netbirdio/netbird/client/firewall/manager" @@ -25,4 +26,5 @@ type HandlerParams struct { UseNewDNSRoute bool Firewall manager.Manager FakeIPManager *fakeip.Manager + ForwarderPort *atomic.Uint32 } diff --git a/client/internal/routemanager/dnsinterceptor/handler.go b/client/internal/routemanager/dnsinterceptor/handler.go index 47c2ffcda..348338dac 100644 --- a/client/internal/routemanager/dnsinterceptor/handler.go +++ b/client/internal/routemanager/dnsinterceptor/handler.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "time" "github.com/hashicorp/go-multierror" @@ -18,7 +19,6 @@ import ( firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" nbdns "github.com/netbirdio/netbird/client/internal/dns" - "github.com/netbirdio/netbird/client/internal/dnsfwd" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/common" @@ -55,6 +55,7 @@ type DnsInterceptor struct { peerStore *peerstore.Store firewall firewall.Manager fakeIPManager *fakeip.Manager + forwarderPort *atomic.Uint32 } func New(params common.HandlerParams) *DnsInterceptor { @@ -69,6 +70,7 @@ func 
New(params common.HandlerParams) *DnsInterceptor { firewall: params.Firewall, fakeIPManager: params.FakeIPManager, interceptedDomains: make(domainMap), + forwarderPort: params.ForwarderPort, } } @@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { r.MsgHdr.AuthenticatedData = true } - upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort()) + upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load())) ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 04513bbe4..26cf758d9 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -10,6 +10,7 @@ import ( "runtime" "slices" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -23,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" + nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peerstore" "github.com/netbirdio/netbird/client/internal/routemanager/client" @@ -54,6 +56,7 @@ type Manager interface { SetRouteChangeListener(listener listener.NetworkChangeListener) InitialRouteRange() []string SetFirewall(firewall.Manager) error + SetDNSForwarderPort(port uint16) Stop(stateManager *statemanager.Manager) } @@ -78,6 +81,7 @@ type DefaultManager struct { ctx context.Context stop context.CancelFunc mux sync.Mutex + shutdownWg sync.WaitGroup clientNetworks map[route.HAUniqueID]*client.Watcher routeSelector *routeselector.RouteSelector serverRouter *server.Router @@ -101,12 +105,13 @@ type DefaultManager struct { disableServerRoutes bool activeRoutes map[route.HAUniqueID]client.RouteHandler fakeIPManager *fakeip.Manager + dnsForwarderPort atomic.Uint32 } 
func NewManager(config ManagerConfig) *DefaultManager { mCTX, cancel := context.WithCancel(config.Context) notifier := notifier.NewNotifier() - sysOps := systemops.NewSysOps(config.WGInterface, notifier) + sysOps := systemops.New(config.WGInterface, notifier) if runtime.GOOS == "windows" && config.WGInterface != nil { nbnet.SetVPNInterfaceName(config.WGInterface.Name()) @@ -130,6 +135,7 @@ func NewManager(config ManagerConfig) *DefaultManager { disableServerRoutes: config.DisableServerRoutes, activeRoutes: make(map[route.HAUniqueID]client.RouteHandler), } + dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort)) useNoop := netstack.IsEnabled() || config.DisableClientRoutes dm.setupRefCounters(useNoop) @@ -270,9 +276,15 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error { return nil } +// SetDNSForwarderPort sets the DNS forwarder port for route handlers +func (m *DefaultManager) SetDNSForwarderPort(port uint16) { + m.dnsForwarderPort.Store(uint32(port)) +} + // Stop stops the manager watchers and clean firewall rules func (m *DefaultManager) Stop(stateManager *statemanager.Manager) { m.stop() + m.shutdownWg.Wait() if m.serverRouter != nil { m.serverRouter.CleanUp() } @@ -345,6 +357,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { UseNewDNSRoute: m.useNewDNSRoute, Firewall: m.firewall, FakeIPManager: m.fakeIPManager, + ForwarderPort: &m.dnsForwarderPort, } handler := client.HandlerFromRoute(params) if err := handler.AddRoute(m.ctx); err != nil { @@ -474,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) { } clientNetworkWatcher := client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes}) } @@ -516,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial 
uint64, networks rout } clientNetworkWatcher = client.NewWatcher(config) m.clientNetworks[id] = clientNetworkWatcher - go clientNetworkWatcher.Start() + m.shutdownWg.Add(1) + go func() { + defer m.shutdownWg.Done() + clientNetworkWatcher.Start() + }() } update := client.RoutesUpdate{ UpdateSerial: updateSerial, diff --git a/client/internal/routemanager/mock.go b/client/internal/routemanager/mock.go index be633c3fa..6b06144b2 100644 --- a/client/internal/routemanager/mock.go +++ b/client/internal/routemanager/mock.go @@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error { panic("implement me") } +// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface +func (m *MockManager) SetDNSForwarderPort(port uint16) { +} + // Stop mock implementation of Stop from Manager interface func (m *MockManager) Stop(stateManager *statemanager.Manager) { if m.StopFunc != nil { diff --git a/client/internal/routemanager/systemops/flush_nonbsd.go b/client/internal/routemanager/systemops/flush_nonbsd.go new file mode 100644 index 000000000..f1c45d6cf --- /dev/null +++ b/client/internal/routemanager/systemops/flush_nonbsd.go @@ -0,0 +1,8 @@ +//go:build !((darwin && !ios) || dragonfly || freebsd || netbsd || openbsd) + +package systemops + +// FlushMarkedRoutes is a no-op on non-BSD platforms. 
+func (r *SysOps) FlushMarkedRoutes() error { + return nil +} diff --git a/client/internal/routemanager/systemops/state.go b/client/internal/routemanager/systemops/state.go index 8e158711e..e0d045b07 100644 --- a/client/internal/routemanager/systemops/state.go +++ b/client/internal/routemanager/systemops/state.go @@ -13,11 +13,11 @@ func (s *ShutdownState) Name() string { } func (s *ShutdownState) Cleanup() error { - sysops := NewSysOps(nil, nil) - sysops.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysops.removeFromRouteTable) - sysops.refCounter.LoadData((*ExclusionCounter)(s)) + sysOps := New(nil, nil) + sysOps.refCounter = refcounter.New[netip.Prefix, struct{}, Nexthop](nil, sysOps.removeFromRouteTable) + sysOps.refCounter.LoadData((*ExclusionCounter)(s)) - return sysops.refCounter.Flush() + return sysOps.refCounter.Flush() } func (s *ShutdownState) MarshalJSON() ([]byte, error) { diff --git a/client/internal/routemanager/systemops/systemops.go b/client/internal/routemanager/systemops/systemops.go index 8da138117..c0ca21d22 100644 --- a/client/internal/routemanager/systemops/systemops.go +++ b/client/internal/routemanager/systemops/systemops.go @@ -83,7 +83,7 @@ type SysOps struct { localSubnetsCacheTime time.Time } -func NewSysOps(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { +func New(wgInterface wgIface, notifier *notifier.Notifier) *SysOps { return &SysOps{ wgInterface: wgInterface, notifier: notifier, diff --git a/client/internal/routemanager/systemops/systemops_bsd_test.go b/client/internal/routemanager/systemops/systemops_bsd_test.go index 0d892c162..ec4fc406e 100644 --- a/client/internal/routemanager/systemops/systemops_bsd_test.go +++ b/client/internal/routemanager/systemops/systemops_bsd_test.go @@ -42,7 +42,7 @@ func TestConcurrentRoutes(t *testing.T) { _, intf = setupDummyInterface(t) nexthop = Nexthop{netip.Addr{}, intf} - r := NewSysOps(nil, nil) + r := New(nil, nil) var wg sync.WaitGroup for i := 0; i < 1024; 
i++ { @@ -146,7 +146,7 @@ func createAndSetupDummyInterface(t *testing.T, intf string, ipAddressCIDR strin nexthop := Nexthop{netip.Addr{}, netIntf} - r := NewSysOps(nil, nil) + r := New(nil, nil) err = r.addToRouteTable(prefix, nexthop) require.NoError(t, err, "Failed to add route to table") diff --git a/client/internal/routemanager/systemops/systemops_generic_test.go b/client/internal/routemanager/systemops/systemops_generic_test.go index 32ea38a7a..d9b109beb 100644 --- a/client/internal/routemanager/systemops/systemops_generic_test.go +++ b/client/internal/routemanager/systemops/systemops_generic_test.go @@ -143,7 +143,7 @@ func TestAddVPNRoute(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun53%d", n), "100.65.75.2/24", 33100+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -342,7 +342,7 @@ func TestAddRouteToNonVPNIntf(t *testing.T) { wgInterface := createWGInterface(t, fmt.Sprintf("utun54%d", n), "100.65.75.2/24", 33200+n) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err) @@ -486,7 +486,7 @@ func setupTestEnv(t *testing.T) { assert.NoError(t, wgInterface.Close()) }) - r := NewSysOps(wgInterface, nil) + r := New(wgInterface, nil) advancedRouting := nbnet.AdvancedRouting() err := r.SetupRouting(nil, nil, advancedRouting) require.NoError(t, err, "setupRouting should not return err") diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index d43c2d5bf..7089178fb 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -7,19 +7,39 @@ import ( "fmt" "net" "net/netip" + "os" "strconv" "syscall" "time" "unsafe" 
"github.com/cenkalti/backoff/v4" + "github.com/hashicorp/go-multierror" log "github.com/sirupsen/logrus" "golang.org/x/net/route" "golang.org/x/sys/unix" + nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal/statemanager" ) +const ( + envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" +) + +var routeProtoFlag int + +func init() { + switch os.Getenv(envRouteProtoFlag) { + case "2": + routeProtoFlag = unix.RTF_PROTO2 + case "3": + routeProtoFlag = unix.RTF_PROTO3 + default: + routeProtoFlag = unix.RTF_PROTO1 + } +} + func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { return r.setupRefCounter(initAddresses, stateManager) } @@ -28,6 +48,62 @@ func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRout return r.cleanupRefCounter(stateManager) } +// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +func (r *SysOps) FlushMarkedRoutes() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + flushedCount := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + + routeInfo, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("Skipping route flush: %v", err) + continue + } + + if !routeInfo.Dst.IsValid() || !routeInfo.Dst.IsSingleIP() { + continue + } + + nexthop := Nexthop{ + IP: routeInfo.Gw, + Intf: routeInfo.Interface, + } + + if err := r.removeFromRouteTable(routeInfo.Dst, nexthop); err != nil { + merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", routeInfo.Dst, err)) + continue + } + + flushedCount++ + log.Debugf("Flushed marked route: %s", routeInfo.Dst) + } + 
+ if flushedCount > 0 { + log.Infof("Flushed %d residual NetBird routes from previous session", flushedCount) + } + + return nberrors.FormatErrorOrNil(merr) +} + func (r *SysOps) addToRouteTable(prefix netip.Prefix, nexthop Nexthop) error { return r.routeSocket(unix.RTM_ADD, prefix, nexthop) } @@ -105,7 +181,7 @@ func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func( func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { msg = &route.RouteMessage{ Type: action, - Flags: unix.RTF_UP, + Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, Seq: r.getSeq(), } diff --git a/client/internal/routeselector/routeselector.go b/client/internal/routeselector/routeselector.go index e4a78599e..61c8bbc79 100644 --- a/client/internal/routeselector/routeselector.go +++ b/client/internal/routeselector/routeselector.go @@ -9,8 +9,6 @@ import ( "github.com/hashicorp/go-multierror" "golang.org/x/exp/maps" - log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/route" ) @@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool { defer rs.mu.RUnlock() if rs.deselectAll { - log.Debugf("Route %s not selected (deselect all)", routeID) return false } _, deselected := rs.deselectedRoutes[routeID] isSelected := !deselected - log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected) return isSelected } diff --git a/client/internal/statemanager/manager.go b/client/internal/statemanager/manager.go index 29f962ad2..2c9e46290 100644 --- a/client/internal/statemanager/manager.go +++ b/client/internal/statemanager/manager.go @@ -295,7 +295,7 @@ func (m *Manager) loadStateFile(deleteCorrupt bool) (map[string]json.RawMessage, data, err := os.ReadFile(m.filePath) if err != nil { if errors.Is(err, fs.ErrNotExist) { - log.Debug("state file does not exist") + log.Debugf("state file %s 
does not exist", m.filePath) return nil, nil // nolint:nilnil } return nil, fmt.Errorf("read state file: %w", err) diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..fa1c89aab 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") if err != nil { return err.Error() } diff --git a/client/net/conn.go b/client/net/conn.go index 918e7f628..bf54c792d 100644 --- a/client/net/conn.go +++ b/client/net/conn.go @@ -17,8 +17,7 @@ type Conn struct { ID hooks.ConnectionID } -// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection -// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection. func (c *Conn) Close() error { return closeConn(c.ID, c.Conn) } @@ -29,7 +28,7 @@ type TCPConn struct { ID hooks.ConnectionID } -// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection. func (c *TCPConn) Close() error { return closeConn(c.ID, c.TCPConn) } @@ -37,13 +36,16 @@ func (c *TCPConn) Close() error { // closeConn is a helper function to close connections and execute close hooks. func closeConn(id hooks.ConnectionID, conn io.Closer) error { err := conn.Close() + cleanupConnID(id) + return err +} +// cleanupConnID executes close hooks for a connection ID. 
+func cleanupConnID(id hooks.ConnectionID) { closeHooks := hooks.GetCloseHooks() for _, hook := range closeHooks { if err := hook(id); err != nil { log.Errorf("Error executing close hook: %v", err) } } - - return err } diff --git a/client/net/dial.go b/client/net/dial.go index 041a00e5d..17c9ff98a 100644 --- a/client/net/dial.go +++ b/client/net/dial.go @@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro } return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil } - if err := conn.Close(); err != nil { log.Errorf("failed to close connection: %v", err) } diff --git a/client/net/dialer_dial.go b/client/net/dialer_dial.go index 2e1eb53d8..1e275013f 100644 --- a/client/net/dialer_dial.go +++ b/client/net/dialer_dial.go @@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net. conn, err := d.Dialer.DialContext(ctx, network, address) if err != nil { + cleanupConnID(connID) return nil, fmt.Errorf("d.Dialer.DialContext: %w", err) } @@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str ips, err := resolver.LookupIPAddr(ctx, host) if err != nil { - return fmt.Errorf("failed to resolve address %s: %w", address, err) + return fmt.Errorf("resolve address %s: %w", address, err) } log.Debugf("Dialer resolved IPs for %s: %v", address, ips) diff --git a/client/net/listener_listen.go b/client/net/listener_listen.go index 0bb5ad67d..a150172b4 100644 --- a/client/net/listener_listen.go +++ b/client/net/listener_listen.go @@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.PacketConn.WriteTo(b, addr) } -// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection. 
func (c *PacketConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.PacketConn) @@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { return c.UDPConn.WriteTo(b, addr) } -// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection. +// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection. func (c *UDPConn) Close() error { defer c.seenAddrs.Clear() return closeConn(c.ID, c.UDPConn) diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 841e3c0f7..02f09b08a 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -279,8 +279,10 @@ type LoginRequest struct { ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // hint is used to pre-fill the email/username field during SSO authentication + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 { return 0 } +func (x *LoginRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a 
google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\xe5\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\x04_mtuB\a\n" + + "\x05_hint\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5b27b4d98..8d1080051 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -158,6 +158,9 @@ message LoginRequest { optional string username = 31; optional int64 mtu = 32; + + // hint is used to pre-fill the email/username field during SSO authentication + optional string hint = 33; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 89f50a1ef..6699cdadc 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := 
auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err @@ -1057,10 +1061,7 @@ func (s *Server) Status( s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive) if msg.GetFullPeerStatus { - if msg.ShouldRunProbes { - s.runProbes() - } - + s.runProbes(msg.ShouldRunProbes) fullStatus := s.statusRecorder.GetFullStatus() pbFullStatus := toProtoFullStatus(fullStatus) pbFullStatus.Events = s.statusRecorder.GetEventHistory() @@ -1070,7 +1071,7 @@ func (s *Server) Status( return &statusResponse, nil } -func (s *Server) runProbes() { +func (s *Server) runProbes(waitForProbeResult bool) { if s.connectClient == nil { return } @@ -1081,7 +1082,7 @@ func (s *Server) runProbes() { } if time.Since(s.lastProbe) > probeThreshold { - if engine.RunHealthProbes() { + if engine.RunHealthProbes(waitForProbeResult) { s.lastProbe = time.Now() } } diff --git a/client/server/state.go b/client/server/state.go index 107f55154..1cf85cd37 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -10,7 +10,9 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -135,5 +137,12 @@ func restoreResidualState(ctx context.Context, statePath string) error { merr = multierror.Append(merr, fmt.Errorf("persist state: %w", err)) } + // clean up any remaining routes independently of the state file + if !nbnet.AdvancedRouting() { + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) 
+ } + } + return nberrors.FormatErrorOrNil(merr) } diff --git a/client/status/status.go b/client/status/status.go index 5e4fcd8dc..8a0b7bae0 100644 --- a/client/status/status.go +++ b/client/status/status.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/internal/peer" + probeRelay "github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/version" @@ -340,10 +341,16 @@ func ParseGeneralSummary(overview OutputOverview, showURL bool, showRelays bool, for _, relay := range overview.Relays.Details { available := "Available" reason := "" + if !relay.Available { - available = "Unavailable" - reason = fmt.Sprintf(", reason: %s", relay.Error) + if relay.Error == probeRelay.ErrCheckInProgress.Error() { + available = "Checking..." + } else { + available = "Unavailable" + reason = fmt.Sprintf(", reason: %s", relay.Error) + } } + relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason) } } else { diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index 865dd2731..1cd4887ce 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -301,6 +301,8 @@ type serviceClient struct { wLoginURL fyne.Window wUpdateProgress fyne.Window updateContextCancel context.CancelFunc + + connectCancel context.CancelFunc } type menuHandler struct { @@ -624,17 +626,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form { } } -func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { +func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) { conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return nil, err + return nil, fmt.Errorf("get daemon client: %w", err) } activeProf, err := s.profileManager.GetActiveProfile() if err != nil { - log.Errorf("get active 
profile: %v", err) - return nil, err + return nil, fmt.Errorf("get active profile: %w", err) } currUser, err := user.Current() @@ -642,84 +642,80 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{ + loginReq := &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, - }) + } + + profileState, err := s.profileManager.GetProfileState(activeProf.Name) if err != nil { - log.Errorf("login to management URL with: %v", err) - return nil, err + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginReq.Hint = &profileState.Email + } + + loginResp, err := conn.Login(ctx, loginReq) + if err != nil { + return nil, fmt.Errorf("login to management: %w", err) } if loginResp.NeedsSSOLogin && openURL { - err = s.handleSSOLogin(loginResp, conn) - if err != nil { - log.Errorf("handle SSO login failed: %v", err) - return nil, err + if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil { + return nil, fmt.Errorf("SSO login: %w", err) } } return loginResp, nil } -func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { - err := openURL(loginResp.VerificationURIComplete) - if err != nil { - log.Errorf("opening the verification uri in the browser failed: %v", err) - return err +func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error { + if err := openURL(loginResp.VerificationURIComplete); err != nil { + return fmt.Errorf("open browser: %w", err) } - resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) + resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode}) if err != nil { - 
log.Errorf("waiting sso login failed with: %v", err) - return err + return fmt.Errorf("wait for SSO login: %w", err) } if resp.Email != "" { - err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ + if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{ Email: resp.Email, - }) - if err != nil { - log.Warnf("failed to set profile state: %v", err) + }); err != nil { + log.Debugf("failed to set profile state: %v", err) } else { s.mProfile.refresh() } - } return nil } -func (s *serviceClient) menuUpClick() error { +func (s *serviceClient) menuUpClick(ctx context.Context) error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { systray.SetTemplateIcon(iconErrorMacOS, s.icError) - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } - _, err = s.login(true) + _, err = s.login(ctx, true) if err != nil { - log.Errorf("login failed with: %v", err) - return err + return fmt.Errorf("login: %w", err) } - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status == string(internal.StatusConnected) { - log.Warnf("already connected") return nil } - if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil { - log.Errorf("up service: %v", err) - return err + if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil { + return fmt.Errorf("start connection: %w", err) } return nil @@ -729,24 +725,20 @@ func (s *serviceClient) menuDownClick() error { systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting) conn, err := s.getSrvClient(defaultFailTimeout) if err != nil { - log.Errorf("get client: %v", err) - return err + return fmt.Errorf("get daemon client: %w", err) } status, err := conn.Status(s.ctx, 
&proto.StatusRequest{}) if err != nil { - log.Errorf("get service status: %v", err) - return err + return fmt.Errorf("get status: %w", err) } if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) { - log.Warnf("already down") return nil } - if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil { - log.Errorf("down service: %v", err) - return err + if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil { + return fmt.Errorf("stop connection: %w", err) } return nil @@ -882,6 +874,7 @@ func (s *serviceClient) onTrayReady() { newProfileMenuArgs := &newProfileMenuArgs{ ctx: s.ctx, + serviceClient: s, profileManager: s.profileManager, eventHandler: s.eventHandler, profileMenuItem: profileMenuItem, @@ -1436,7 +1429,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - resp, err := s.login(false) + resp, err := s.login(ctx, false) if err != nil { log.Errorf("failed to fetch login URL: %v", err) return @@ -1456,7 +1449,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) + _, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode}) if err != nil { log.Errorf("Waiting sso login failed with: %v", err) label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.") @@ -1464,7 +1457,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { } label.SetText("Re-authentication successful.\nReconnecting") - status, err := conn.Status(s.ctx, &proto.StatusRequest{}) + status, err := conn.Status(ctx, &proto.StatusRequest{}) if err != nil { log.Errorf("get service status: %v", err) return @@ -1477,7 +1470,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc { return } - _, err = conn.Up(s.ctx, &proto.UpRequest{}) + _, err = conn.Up(ctx, &proto.UpRequest{}) if err != nil { label.SetText("Reconnecting 
failed, please create \na debug bundle in the settings and contact support.") log.Errorf("Reconnecting failed with: %v", err) diff --git a/client/ui/event_handler.go b/client/ui/event_handler.go index e9b7f4f30..e0b619411 100644 --- a/client/ui/event_handler.go +++ b/client/ui/event_handler.go @@ -12,6 +12,8 @@ import ( "fyne.io/fyne/v2" "fyne.io/systray" log "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/version" @@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) { func (h *eventHandler) handleConnectClick() { h.client.mUp.Disable() + + if h.client.connectCancel != nil { + h.client.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(h.client.ctx) + h.client.connectCancel = connectCancel + go func() { - defer h.client.mUp.Enable() - if err := h.client.menuUpClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service")) + defer connectCancel() + + if err := h.client.menuUpClick(connectCtx); err != nil { + st, ok := status.FromError(err) + if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { + log.Debugf("connect operation cancelled by user") + } else { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect")) + log.Errorf("connect failed: %v", err) + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after connect: %v", err) } }() } func (h *eventHandler) handleDisconnectClick() { h.client.mDown.Disable() + + if h.client.connectCancel != nil { + log.Debugf("cancelling ongoing connect operation") + h.client.connectCancel() + h.client.connectCancel = nil + } + go func() { - defer h.client.mDown.Enable() if err := h.client.menuDownClick(); err != nil { - h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon")) + 
st, ok := status.FromError(err) + if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { + h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect")) + log.Errorf("disconnect failed: %v", err) + } else { + log.Debugf("disconnect cancelled or already disconnecting") + } + } + + if err := h.client.updateStatus(); err != nil { + log.Debugf("failed to update status after disconnect: %v", err) } }() } @@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error { } h.client.getSrvConfig() - + return nil } diff --git a/client/ui/profile.go b/client/ui/profile.go index 075223795..74189c9a0 100644 --- a/client/ui/profile.go +++ b/client/ui/profile.go @@ -387,6 +387,7 @@ type subItem struct { type profileMenu struct { mu sync.Mutex ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem @@ -396,7 +397,7 @@ type profileMenu struct { logoutSubItem *subItem profilesState []Profile downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -404,12 +405,13 @@ type profileMenu struct { type newProfileMenuArgs struct { ctx context.Context + serviceClient *serviceClient profileManager *profilemanager.ProfileManager eventHandler *eventHandler profileMenuItem *systray.MenuItem emailMenuItem *systray.MenuItem downClickCallback func() error - upClickCallback func() error + upClickCallback func(context.Context) error getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error) loadSettingsCallback func() app fyne.App @@ -418,6 +420,7 @@ type newProfileMenuArgs struct { func newProfileMenu(args newProfileMenuArgs) *profileMenu { p := profileMenu{ ctx: args.ctx, + serviceClient: args.serviceClient, profileManager: 
args.profileManager, eventHandler: args.eventHandler, profileMenuItem: args.profileMenuItem, @@ -569,10 +572,19 @@ func (p *profileMenu) refresh() { } } - if err := p.upClickCallback(); err != nil { + if p.serviceClient.connectCancel != nil { + p.serviceClient.connectCancel() + } + + connectCtx, connectCancel := context.WithCancel(p.ctx) + p.serviceClient.connectCancel = connectCancel + + if err := p.upClickCallback(connectCtx); err != nil { log.Errorf("failed to handle up click after switching profile: %v", err) } + connectCancel() + p.refresh() p.loadSettingsCallback() } diff --git a/dns/dns.go b/dns/dns.go index f889a32ec..cf089d4ed 100644 --- a/dns/dns.go +++ b/dns/dns.go @@ -19,6 +19,10 @@ const ( RootZone = "." // DefaultClass is the class supported by the system DefaultClass = "IN" + // ForwarderClientPort is the port clients connect to. DNAT rewrites packets from ForwarderClientPort to ForwarderServerPort. + ForwarderClientPort uint16 = 5353 + // ForwarderServerPort is the port the DNS forwarder actually listens on. Packets to ForwarderClientPort are DNATed here. 
+ ForwarderServerPort uint16 = 22054 ) const invalidHostLabel = "[^a-zA-Z0-9-]+" @@ -31,6 +35,8 @@ type Config struct { NameServerGroups []*NameServerGroup // CustomZones contains a list of custom zone CustomZones []CustomZone + // ForwarderPort is the port clients should connect to on routing peers for DNS forwarding + ForwarderPort uint16 } // CustomZone represents a custom zone to be resolved by the dns server diff --git a/go.mod b/go.mod index 79dd92e6b..7b9bae321 100644 --- a/go.mod +++ b/go.mod @@ -56,13 +56,14 @@ require ( github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/go-secure-stdlib/base62 v0.1.2 github.com/hashicorp/go-version v1.6.0 + github.com/jackc/pgx/v5 v5.5.5 github.com/libdns/route53 v1.5.0 github.com/libp2p/go-netroute v0.2.1 github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/nadoo/ipset v0.5.0 - github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 + github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/okta/okta-sdk-golang/v2 v2.18.0 github.com/oschwald/maxminddb-golang v1.12.0 @@ -76,7 +77,7 @@ require ( github.com/pion/transport/v3 v3.0.7 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.22.0 - github.com/quic-go/quic-go v0.48.2 + github.com/quic-go/quic-go v0.49.1 github.com/redis/go-redis/v9 v9.7.3 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 @@ -102,11 +103,12 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a - golang.org/x/mod v0.25.0 + golang.org/x/mod v0.26.0 golang.org/x/net v0.42.0 - golang.org/x/oauth2 v0.28.0 + golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.16.0 golang.org/x/term v0.33.0 + golang.org/x/time v0.12.0 google.golang.org/api 
v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -146,7 +148,7 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/caddyserver/zerossl v0.1.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/containerd/containerd v1.7.27 // indirect + github.com/containerd/containerd v1.7.29 // indirect github.com/containerd/log v0.1.0 // indirect github.com/containerd/platforms v0.2.1 // indirect github.com/cpuguy83/dockercfg v0.3.2 // indirect @@ -183,7 +185,6 @@ require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect github.com/jinzhu/inflection v1.0.0 // indirect @@ -241,11 +242,10 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect go.opentelemetry.io/otel/sdk v1.35.0 // indirect go.opentelemetry.io/otel/trace v1.35.0 // indirect - go.uber.org/mock v0.4.0 // indirect + go.uber.org/mock v0.5.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.18.0 // indirect golang.org/x/text v0.27.0 // indirect - golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.34.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect diff --git a/go.sum b/go.sum index f0065e081..61ad8740e 100644 --- a/go.sum +++ b/go.sum @@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE= github.com/coder/websocket v1.8.13/go.mod 
h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs= -github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII= -github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0= +github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE= +github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs= github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A= @@ -503,8 +503,8 @@ github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6S github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396 h1:aXHS63QWf0Z5fDN19Swl6npdJjGMyXthAvvgW7rbKJQ= -github.com/netbirdio/management-integrations/integrations v0.0.0-20251022080146-b1caade69396/go.mod h1:v0nUbbHbuQnqR7yKIYnKzsLBCswLtp2JctmKYmGgVhc= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48 h1:moJbL1uuaWR35yUgHZ6suijjqqW8/qGCuPPBXu5MeWQ= +github.com/netbirdio/management-integrations/integrations v0.0.0-20251027212525-d751b79f5d48/go.mod h1:ifKa2jGPsOzZhJFo72v2AE5nMP3GYvlhoZ9JV6lHlJ8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher 
v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= -github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE= -github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs= +github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0= +github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s= github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM= github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA= github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= @@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8 go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= -go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= -go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= +go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= @@ -818,8 +818,8 @@ golang.org/x/mod 
v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= -golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= +golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= +golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc= -golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= +golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI= +golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod 
h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index daec4ef6f..209a20065 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -35,7 +35,13 @@ func (s *BaseServer) GeoLocationManager() geolocation.Geolocation { func (s *BaseServer) PermissionsManager() permissions.Manager { return Create(s, func() permissions.Manager { - return integrations.InitPermissionsManager(s.Store()) + manager := integrations.InitPermissionsManager(s.Store(), s.Metrics().GetMeter()) + + s.AfterInit(func(s *BaseServer) { + manager.SetAccountManager(s.AccountManager()) + }) + + return manager }) } diff --git a/management/main.go b/management/main.go index 561ed8f26..ff8482f97 100644 --- a/management/main.go +++ b/management/main.go @@ -1,11 +1,19 @@ package main import ( - 
"github.com/netbirdio/netbird/management/cmd" + "log" + "net/http" + // nolint:gosec + _ "net/http/pprof" "os" + + "github.com/netbirdio/netbird/management/cmd" ) func main() { + go func() { + log.Println(http.ListenAndServe("localhost:6060", nil)) + }() if err := cmd.Execute(); err != nil { os.Exit(1) } diff --git a/management/server/account.go b/management/server/account.go index fcdab4b69..1357915e0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -53,6 +53,9 @@ const ( peerSchedulerRetryInterval = 3 * time.Second emptyUserID = "empty user ID in claims" errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v" + + envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP" + envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS" ) type userLoggedInOnce bool @@ -109,6 +112,11 @@ type DefaultAccountManager struct { loginFilter *loginFilter disableDefaultPolicy bool + + holder *types.Holder + + expNewNetworkMap bool + expNewNetworkMapAIDs map[string]struct{} } func isUniqueConstraintError(err error) bool { @@ -196,6 +204,18 @@ func BuildManager( log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start)) }() + newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder)) + if err != nil { + log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err) + newNetworkMapBuilder = false + } + + ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",") + expIDs := make(map[string]struct{}, len(ids)) + for _, id := range ids { + expIDs[id] = struct{}{} + } + am := &DefaultAccountManager{ Store: store, geo: geo, @@ -217,6 +237,10 @@ func BuildManager( permissionsManager: permissionsManager, loginFilter: newLoginFilter(), disableDefaultPolicy: disableDefaultPolicy, + holder: types.NewHolder(), + + expNewNetworkMap: newNetworkMapBuilder, + expNewNetworkMapAIDs: expIDs, } am.startWarmup(ctx) @@ -397,6 +421,9 @@ func (am 
*DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco } if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } go am.UpdateAccountPeers(ctx, accountID) } @@ -1487,6 +1514,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth } if removedGroupAffectsPeers || newGroupsAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil { + return err + } + log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId) am.BufferUpdateAccountPeers(ctx, userAuth.AccountId) } @@ -1651,11 +1682,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo } func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start)) - }() - peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err) @@ -2139,6 +2165,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us } if updateNetworkMap { + peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID) + if err != nil { + return err + } + am.updatePeerInNetworkMapCache(peer.AccountID, peer) am.BufferUpdateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index fe9fb25c6..db377865a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -128,4 +128,5 @@ type Manager interface { GetCurrentUserInfo(ctx 
context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) SetEphemeralManager(em ephemeral.Manager) AllowSync(string, uint64) bool + RecalculateNetworkMapCache(ctx context.Context, accountId string) error } diff --git a/management/server/account_test.go b/management/server/account_test.go index 07d2f2383..200ba6b98 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } +func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SaveGroup(t) +} + func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + testAccountManager_NetworkUpdates_SaveGroup(t) +} + +func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePolicy(t) +} + +func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_SavePolicy(t) +} + func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + 
testAccountManager_NetworkUpdates_SavePolicy(t) +} + +func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { manager, account, peer1, peer2, _ := setupNetworkMapTest(t) group := types.Group{ @@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeletePeer(t) +} + func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + testAccountManager_NetworkUpdates_DeletePeer(t) +} + +func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { manager, account, peer1, _, peer3 := setupNetworkMapTest(t) group := types.Group{ @@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { wg.Wait() } +func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + testAccountManager_NetworkUpdates_DeleteGroup(t) +} + +func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) @@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { return } + for drained := false; !drained; { + select { + case <-updMsg: + default: + drained = true + } + } + wg := sync.WaitGroup{} wg.Add(1) go func() { @@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) { Address: "172.12.6.1/24", }, }, + NetworkMapCache: &types.NetworkMapBuilder{}, } + account.InitOnce() err := hasNilField(account) if err != nil { t.Fatal(err) diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b165938..ffecb6b8f 100644 --- 
a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -7,6 +7,7 @@ import ( "path/filepath" "runtime" "strconv" + "time" log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" @@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e return nil, err } - if storeEngine == types.SqliteStoreEngine { - sqlDB.SetMaxOpenConns(1) - } else { - conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) - if err != nil { - conns = runtime.NumCPU() - } - sqlDB.SetMaxOpenConns(conns) + conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv)) + if err != nil { + conns = runtime.NumCPU() } + if storeEngine == types.SqliteStoreEngine { + conns = 1 + } + + sqlDB.SetMaxOpenConns(conns) + sqlDB.SetMaxIdleConns(conns) + sqlDB.SetConnMaxLifetime(time.Hour) + sqlDB.SetConnMaxIdleTime(3 * time.Minute) + + log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) return db, nil } diff --git a/management/server/dns.go b/management/server/dns.go index 534f43ec6..decc5175d 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -21,8 +21,8 @@ import ( ) const ( - dnsForwarderPort = 22054 - oldForwarderPort = 5353 + dnsForwarderPort = nbdns.ForwarderServerPort + oldForwarderPort = nbdns.ForwarderClientPort ) const dnsForwarderPortMinVersion = "v0.59.0" @@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -196,7 +199,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID // If all peers have the required version, it returns the new well-known port (22054), otherwise returns 0. 
func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { if len(peers) == 0 { - return oldForwarderPort + return int64(oldForwarderPort) } reqVer := semver.Canonical(requiredVersion) @@ -211,17 +214,17 @@ func computeForwarderPort(peers []*nbpeer.Peer, requiredVersion string) int64 { peerVersion := semver.Canonical("v" + peer.Meta.WtVersion) if peerVersion == "" { // If any peer doesn't have version info, return 0 - return oldForwarderPort + return int64(oldForwarderPort) } // Compare versions if semver.Compare(peerVersion, reqVer) < 0 { - return oldForwarderPort + return int64(oldForwarderPort) } } // All peers have the required version or newer - return dnsForwarderPort + return int64(dnsForwarderPort) } // toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 83caf74ef..96f73a390 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -394,7 +394,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) @@ -402,7 +402,7 @@ func BenchmarkToProtocolDNSConfig(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { cache := &DNSConfigCache{} - toProtocolDNSConfig(testData, cache, dnsForwarderPort) + toProtocolDNSConfig(testData, cache, int64(dnsForwarderPort)) } }) } @@ -455,13 +455,13 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { } // First run with config1 - result1 := toProtocolDNSConfig(config1, &cache, dnsForwarderPort) + result1 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Second run with config2 - result2 := toProtocolDNSConfig(config2, &cache, dnsForwarderPort) + result2 := toProtocolDNSConfig(config2, &cache, int64(dnsForwarderPort)) // Third run with config1 again - result3 := toProtocolDNSConfig(config1, 
&cache, dnsForwarderPort) + result3 := toProtocolDNSConfig(config1, &cache, int64(dnsForwarderPort)) // Verify that result1 and result3 are identical if !reflect.DeepEqual(result1, result3) { @@ -486,7 +486,7 @@ func TestComputeForwarderPort(t *testing.T) { // Test with empty peers list peers := []*nbpeer.Peer{} result := computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for empty peers list, got %d", oldForwarderPort, result) } @@ -504,7 +504,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with old versions, got %d", oldForwarderPort, result) } @@ -522,7 +522,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != dnsForwarderPort { + if result != int64(dnsForwarderPort) { t.Errorf("Expected %d for peers with new versions, got %d", dnsForwarderPort, result) } @@ -540,7 +540,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with mixed versions, got %d", oldForwarderPort, result) } @@ -553,7 +553,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with empty version, got %d", oldForwarderPort, result) } @@ -565,7 +565,7 @@ func TestComputeForwarderPort(t *testing.T) { }, } result = computeForwarderPort(peers, "v0.59.0") - if result == oldForwarderPort { + if result == int64(oldForwarderPort) { t.Errorf("Expected %d for peers with dev version, got %d", dnsForwarderPort, result) } @@ -578,7 +578,7 @@ func TestComputeForwarderPort(t *testing.T) { }, 
} result = computeForwarderPort(peers, "v0.59.0") - if result != oldForwarderPort { + if result != int64(oldForwarderPort) { t.Errorf("Expected %d for peers with unknown version, got %d", oldForwarderPort, result) } } diff --git a/management/server/group.go b/management/server/group.go index 487cb6d97..3cf9290a2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -138,6 +141,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use return err } + newGroup.AccountID = accountID + + events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) + eventsToStore = append(eventsToStore, events...) + oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID) if err != nil { return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID) @@ -157,11 +165,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } } - newGroup.AccountID = accountID - - events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup) - eventsToStore = append(eventsToStore, events...) 
- updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID}) if err != nil { return err @@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -335,6 +347,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac if err == nil && oldGroup != nil { addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) + + if oldGroup.Name != newGroup.Name { + eventsToStore = append(eventsToStore, func() { + meta := map[string]any{ + "old_name": oldGroup.Name, + "new_name": newGroup.Name, + } + am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta) + }) + } } else { addedPeers = append(addedPeers, newGroup.Peers...) 
eventsToStore = append(eventsToStore, func() { @@ -471,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -509,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -537,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -575,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun } if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 58a8dcd8e..2f0c32821 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -7,8 +7,10 @@ import ( "net" "net/netip" "os" + "strconv" "strings" "sync" + "sync/atomic" "time" pb "github.com/golang/protobuf/proto" // nolint @@ -44,6 +46,9 @@ import ( const ( envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS" envBlockPeers = "NB_BLOCK_SAME_PEERS" + envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS" + + defaultSyncLim = 1000 ) // GRPCServer an instance of a Management gRPC API server @@ -63,6 +68,9 @@ type GRPCServer struct { logBlockedPeers bool blockPeersWithSameConfig bool integratedPeerValidator integrated_validator.IntegratedValidator + + syncSem atomic.Int32 + syncLim int32 } // NewServer creates a new Management server @@ -96,6 +104,17 @@ func NewServer( logBlockedPeers := 
strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true" blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" + syncLim := int32(defaultSyncLim) + if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { + syncLimParsed, err := strconv.Atoi(syncLimStr) + if err != nil { + log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim) + } else { + //nolint:gosec + syncLim = int32(syncLimParsed) + } + } + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -110,6 +129,8 @@ func NewServer( logBlockedPeers: logBlockedPeers, blockPeersWithSameConfig: blockPeersWithSameConfig, integratedPeerValidator: integratedPeerValidator, + + syncLim: syncLim, }, nil } @@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { + if s.syncSem.Load() >= s.syncLim { + return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") + } + s.syncSem.Add(1) + reqStart := time.Now() ctx := srv.Context() @@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi syncReq := &proto.SyncRequest{} peerKey, err := s.parseRequest(ctx, req, syncReq) if err != nil { + s.syncSem.Add(-1) return err } realIP := getRealIP(ctx) @@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed) } if s.blockPeersWithSameConfig { + s.syncSem.Add(-1) return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn) } } @@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv 
proto.ManagementServi // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String()) - unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) - defer func() { - if unlock != nil { - unlock() - } - }() - accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()) if err != nil { // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN") log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String()) if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound { + s.syncSem.Add(-1) return status.Errorf(codes.PermissionDenied, "peer is not registered") } + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + // nolint:staticcheck ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + start := time.Now() + unlock := s.acquirePeerLockByUID(ctx, peerKey.String()) + defer func() { + if unlock != nil { + unlock() + } + }() + log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start)) + log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart)) + log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP) if syncReq.GetMeta() == nil { @@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) + s.syncSem.Add(-1) return mapError(ctx, err) } + log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart)) + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", 
peerKey.String(), err) + s.syncSem.Add(-1) return err } + log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart)) updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID) + log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart)) + s.ephemeralManager.OnPeerConnected(ctx, peer) + log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart)) + s.secretsManager.SetupRefresh(ctx, accountID, peer.ID) + log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart)) + if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID) } @@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart)) + s.syncSem.Add(-1) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) } @@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p //nolint ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) + log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart)) + defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID) } + took := time.Since(reqStart) + if took > 7*time.Second { + log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart)) + } }() if loginReq.GetMeta() == nil { @@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, mapError(ctx, err) } + log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart)) + // if the login request contains setup key then it is a registration request if loginReq.GetSetupKey() != "" { s.ephemeralManager.OnPeerDisconnected(ctx, peer) + log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", 
time.Since(reqStart)) } loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks) @@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p return nil, status.Errorf(codes.Internal, "failed logging in peer") } + log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart)) + encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID) @@ -826,10 +885,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p return status.Errorf(codes.Internal, "error handling request") } + sendStart := time.Now() err = srv.Send(&proto.EncryptedMessage{ WgPubKey: s.wgKey.PublicKey().String(), Body: encryptedResp, }) + log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart)) if err != nil { log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err) diff --git a/management/server/holder.go b/management/server/holder.go new file mode 100644 index 000000000..e8a26e1d0 --- /dev/null +++ b/management/server/holder.go @@ -0,0 +1,39 @@ +package server + +import ( + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) { + a := am.holder.GetAccount(account.Id) + if a == nil { + am.holder.AddAccount(account) + return + } + account.NetworkMapCache = a.NetworkMapCache + if account.NetworkMapCache == nil { + return + } + account.NetworkMapCache.UpdateAccountPointer(account) + am.holder.AddAccount(account) +} + +func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { + return am.holder.GetAccount(accountID) +} + +func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account { + a := am.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := 
am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure) + if err != nil { + return nil + } + return account +} + +func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { + am.holder.AddAccount(account) +} diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 3d4de31d0..4d2c224b4 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -4,9 +4,13 @@ import ( "context" "fmt" "net/http" + "os" + "strconv" + "time" "github.com/gorilla/mux" "github.com/rs/cors" + log "github.com/sirupsen/logrus" "github.com/netbirdio/management-integrations/integrations" @@ -38,7 +42,12 @@ import ( "github.com/netbirdio/netbird/management/server/telemetry" ) -const apiPrefix = "/api" +const ( + apiPrefix = "/api" + rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" + rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" + rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" +) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. 
func NewAPIHandler( @@ -58,11 +67,42 @@ func NewAPIHandler( settingsManager settings.Manager, ) (http.Handler, error) { + var rateLimitingConfig *middleware.RateLimiterConfig + if os.Getenv(rateLimitingEnabledKey) == "true" { + rpm := 6 + if v := os.Getenv(rateLimitingRPMKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) + } else { + rpm = value + } + } + + burst := 500 + if v := os.Getenv(rateLimitingBurstKey); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) + } else { + burst = value + } + } + + rateLimitingConfig = &middleware.RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + } + } + authMiddleware := middleware.NewAuthMiddleware( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, + rateLimitingConfig, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 6091a4c31..bce917a25 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -29,6 +29,7 @@ type AuthMiddleware struct { ensureAccount EnsureAccountFunc getUserFromUserAuth GetUserFromUserAuthFunc syncUserJWTGroups SyncUserJWTGroupsFunc + rateLimiter *APIRateLimiter } // NewAuthMiddleware instance constructor @@ -37,12 +38,19 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, + rateLimiterConfig *RateLimiterConfig, ) *AuthMiddleware { + var rateLimiter *APIRateLimiter + if rateLimiterConfig != nil { + rateLimiter = NewAPIRateLimiter(rateLimiterConfig) + } + return &AuthMiddleware{ 
authManager: authManager, ensureAccount: ensureAccount, syncUserJWTGroups: syncUserJWTGroups, getUserFromUserAuth: getUserFromUserAuth, + rateLimiter: rateLimiter, } } @@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { request, err := m.checkPATFromRequest(r, auth) if err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) - util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) + // Check if it's a status error, otherwise default to Unauthorized + if _, ok := status.FromError(err); !ok { + err = status.Errorf(status.Unauthorized, "token invalid") + } + util.WriteError(r.Context(), err, w) return } h.ServeHTTP(w, request) @@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h return r, fmt.Errorf("error extracting token: %w", err) } + if m.rateLimiter != nil { + if !m.rateLimiter.Allow(token) { + return r, status.Errorf(status.TooManyRequests, "too many requests") + } + } + ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index d815f5422..d1bd9959f 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -27,7 +27,9 @@ const ( domainCategory = "domainCategory" userID = "userID" tokenID = "tokenID" + tokenID2 = "tokenID2" PAT = "nbp_PAT" + PAT2 = "nbp_PAT2" JWT = "JWT" wrongToken = "wrongToken" ) @@ -49,6 +51,15 @@ var testAccount = &types.Account{ CreatedAt: time.Now().UTC(), LastUsed: util.ToPtr(time.Now().UTC()), }, + tokenID2: { + ID: tokenID2, + Name: "My second token", + HashedToken: "someHash2", + ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)), + CreatedBy: userID, + CreatedAt: time.Now().UTC(), + LastUsed: 
util.ToPtr(time.Now().UTC()), + }, }, }, }, @@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use if token == PAT { return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil } + if token == PAT2 { + return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil + } return nil, nil, "", "", fmt.Errorf("PAT invalid") } @@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA } func mockMarkPATUsed(_ context.Context, token string) error { - if token == tokenID { + if token == tokenID || token == tokenID2 { return nil } return fmt.Errorf("Should never get reached") @@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) { } } +func TestAuthMiddleware_RateLimiting(t *testing.T) { + mockAuth := &auth.MockManager{ + ValidateAndParseTokenFunc: mockValidateAndParseToken, + EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups, + MarkPATUsedFunc: mockMarkPATUsed, + GetPATInfoFunc: mockGetAccountInfoFromPAT, + } + + t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) { + // Configure rate limiter: 10 requests per minute with burst of 5 + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 10, + Burst: 5, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth 
nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make burst requests - all should succeed + successCount := 0 + for i := 0; i < 5; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 5, successCount, "All burst requests should succeed") + + // The 6th request should fail (exceeded burst) + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited") + }) + + t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) { + // Configure very low rate limit: 1 request per minute + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + 
handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request should fail (rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + }) + + t.Run("Bearer Token Not Rate Limited", func(t *testing.T) { + // Configure strict rate limit + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Make multiple requests with Bearer token - all should succeed + successCount := 0 + for i := 0; i < 10; i++ { + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + if rec.Code == http.StatusOK { + successCount++ + } + } + + assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)") + }) + + t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) { + // Configure rate limiter + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 1, + Burst: 1, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx 
context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // Use first PAT token + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed") + + // Second request with same token should fail + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited") + + // Use second PAT token - should succeed because it has independent rate limit + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)") + + // Second request with PAT2 should also be rate limited + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT2) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited") + + // JWT should still work (not rate limited) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Bearer "+JWT) + rec = httptest.NewRecorder() 
+ handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)") + }) + + t.Run("Rate Limiter Cleanup", func(t *testing.T) { + // Configure rate limiter with short cleanup interval and TTL for testing + rateLimitConfig := &RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: 100 * time.Millisecond, + LimiterTTL: 200 * time.Millisecond, + } + + authMiddleware := NewAuthMiddleware( + mockAuth, + func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { + return userAuth.AccountId, userAuth.UserId, nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) error { + return nil + }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, + rateLimitConfig, + ) + + handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request - should succeed + req := httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed") + + // Second request immediately - should fail (burst exhausted) + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited") + + // Wait for limiter to be cleaned up (TTL + cleanup interval + buffer) + time.Sleep(400 * time.Millisecond) + + // After cleanup, the limiter should be removed and recreated with full burst capacity + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusOK, 
rec.Code, "Request after cleanup should succeed (new limiter with full burst)") + + // Verify it's a fresh limiter by checking burst is reset + req = httptest.NewRequest("GET", "http://testing/test", nil) + req.Header.Set("Authorization", "Token "+PAT) + rec = httptest.NewRecorder() + handler.ServeHTTP(rec, req) + assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again") + }) +} + func TestAuthMiddleware_Handler_Child(t *testing.T) { tt := []struct { name string @@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { return &types.User{}, nil }, + nil, ) for _, tc := range tt { diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go new file mode 100644 index 000000000..a6266d4f3 --- /dev/null +++ b/management/server/http/middleware/rate_limiter.go @@ -0,0 +1,146 @@ +package middleware + +import ( + "context" + "sync" + "time" + + "golang.org/x/time/rate" +) + +// RateLimiterConfig holds configuration for the API rate limiter +type RateLimiterConfig struct { + // RequestsPerMinute defines the rate at which tokens are replenished + RequestsPerMinute float64 + // Burst defines the maximum number of requests that can be made in a burst + Burst int + // CleanupInterval defines how often to clean up old limiters (how often garbage collection runs) + CleanupInterval time.Duration + // LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal) + LimiterTTL time.Duration +} + +// DefaultRateLimiterConfig returns a default configuration +func DefaultRateLimiterConfig() *RateLimiterConfig { + return &RateLimiterConfig{ + RequestsPerMinute: 100, + Burst: 120, + CleanupInterval: 5 * time.Minute, + LimiterTTL: 10 * time.Minute, + } +} + +// limiterEntry holds a rate limiter and its last access time +type limiterEntry struct { + 
limiter *rate.Limiter + lastAccess time.Time +} + +// APIRateLimiter manages rate limiting for API tokens +type APIRateLimiter struct { + config *RateLimiterConfig + limiters map[string]*limiterEntry + mu sync.RWMutex + stopChan chan struct{} +} + +// NewAPIRateLimiter creates a new API rate limiter with the given configuration +func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { + if config == nil { + config = DefaultRateLimiterConfig() + } + + rl := &APIRateLimiter{ + config: config, + limiters: make(map[string]*limiterEntry), + stopChan: make(chan struct{}), + } + + go rl.cleanupLoop() + + return rl +} + +// Allow checks if a request for the given key (token) is allowed +func (rl *APIRateLimiter) Allow(key string) bool { + limiter := rl.getLimiter(key) + return limiter.Allow() +} + +// Wait blocks until the rate limiter allows another request for the given key +// Returns an error if the context is canceled +func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + limiter := rl.getLimiter(key) + return limiter.Wait(ctx) +} + +// getLimiter retrieves or creates a rate limiter for the given key +func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.RLock() + entry, exists := rl.limiters[key] + rl.mu.RUnlock() + + if exists { + rl.mu.Lock() + entry.lastAccess = time.Now() + rl.mu.Unlock() + return entry.limiter + } + + rl.mu.Lock() + defer rl.mu.Unlock() + + if entry, exists := rl.limiters[key]; exists { + entry.lastAccess = time.Now() + return entry.limiter + } + + requestsPerSecond := rl.config.RequestsPerMinute / 60.0 + limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst) + rl.limiters[key] = &limiterEntry{ + limiter: limiter, + lastAccess: time.Now(), + } + + return limiter +} + +// cleanupLoop periodically removes old limiters that haven't been used recently +func (rl *APIRateLimiter) cleanupLoop() { + ticker := time.NewTicker(rl.config.CleanupInterval) + defer ticker.Stop() + + for { + 
select { + case <-ticker.C: + rl.cleanup() + case <-rl.stopChan: + return + } + } +} + +// cleanup removes limiters that haven't been used within the TTL period +func (rl *APIRateLimiter) cleanup() { + rl.mu.Lock() + defer rl.mu.Unlock() + + now := time.Now() + for key, entry := range rl.limiters { + if now.Sub(entry.lastAccess) > rl.config.LimiterTTL { + delete(rl.limiters, key) + } + } +} + +// Stop stops the cleanup goroutine +func (rl *APIRateLimiter) Stop() { + close(rl.stopChan) +} + +// Reset removes the rate limiter for a specific key +func (rl *APIRateLimiter) Reset(key string) { + rl.mu.Lock() + defer rl.mu.Unlock() + delete(rl.limiters, key) +} diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 741f03f18..bdf56db6e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -7,9 +7,10 @@ import ( "time" "github.com/golang-jwt/jwt/v5" - "github.com/netbirdio/management-integrations/integrations" "github.com/stretchr/testify/assert" + "github.com/netbirdio/management-integrations/integrations" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index e87043f26..8baffa58b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -125,9 +125,10 @@ type MockAccountManager struct { UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - AllowSyncFunc func(string, uint64) bool - 
UpdateAccountPeersFunc func(ctx context.Context, accountID string) - BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + AllowSyncFunc func(string, uint64) bool + UpdateAccountPeersFunc func(ctx context.Context, accountID string) + BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string) + RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error } func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error { @@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool { } return true } + +func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error { + if am.RecalculateNetworkMapCacheFunc != nil { + return am.RecalculateNetworkMapCacheFunc(ctx, accountID) + } + return nil +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index f278e1761..ee77a65bb 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, 
nsGroup.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/networkmap.go b/management/server/networkmap.go new file mode 100644 index 000000000..2a0627643 --- /dev/null +++ b/management/server/networkmap.go @@ -0,0 +1,80 @@ +package server + +import ( + "context" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/management/server/types" +) + +func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) { + am.enrichAccountFromHolder(account) + account.InitNetworkMapBuilderIfNeeded(validatedPeers) +} + +func (am *DefaultAccountManager) getPeerNetworkMapExp( + ctx context.Context, + accountId string, + peerId string, + validatedPeers map[string]struct{}, + customZone nbdns.CustomZone, + metrics *telemetry.AccountManagerMetrics, +) *types.NetworkMap { + account := am.getAccountFromHolderOrInit(accountId) + if account == nil { + log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId) + return &types.NetworkMap{ + Network: &types.Network{}, + } + } + return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) +} + +func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerAddedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error { + am.enrichAccountFromHolder(account) + return account.OnPeerDeletedUpdNetworkMapCache(peerId) +} + +func (am *DefaultAccountManager) 
updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) { + account := am.getAccountFromHolder(accountId) + if account == nil { + return + } + account.UpdatePeerInNetworkMapCache(peer) +} + +func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) { + account.RecalculateNetworkMapCache(validatedPeers) + am.updateAccountInHolder(account) +} + +func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { + if am.experimentalNetworkMap(accountId) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + if err != nil { + return err + } + validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) + if err != nil { + log.WithContext(ctx).Errorf("failed to get validate peers: %v", err) + return err + } + am.recalculateNetworkMapCache(account, validatedPeers) + } + return nil +} + +func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool { + _, ok := am.expNewNetworkMapAIDs[accountId] + return am.expNewNetworkMap || ok +} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index b6706ca45..0e6d1631b 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 66484d120..b740610c2 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -157,6 +157,9 @@ func (m *managerImpl) 
CreateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID) return resource, nil @@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net event() } + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 82cac424a..89ac419fd 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network)) + if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil { + return nil, err + } go m.accountManager.UpdateAccountPeers(ctx, router.AccountID) return router, nil @@ -217,6 +223,9 @@ func (m 
*managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo event() + if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } go m.accountManager.UpdateAccountPeers(ctx, accountID) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 276a06b1a..9475ac55e 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var settings *types.Settings var expired bool @@ -145,6 +140,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. 
am.BufferUpdateAccountPeers(ctx, accountID) @@ -321,6 +319,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + if peerLabelChanged || requiresPeerUpdates { am.UpdateAccountPeers(ctx, accountID) } else if sshChanged { @@ -381,6 +383,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return err + } + + if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err) + } + + } + if userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } @@ -417,7 +431,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin return nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(peer.AccountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -690,6 +710,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + if am.experimentalNetworkMap(accountID) { + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + + if err := 
am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil { + log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) + } + } + am.BufferUpdateAccountPeers(ctx, accountID) return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) @@ -708,11 +739,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) { // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start)) - }() - var peer *nbpeer.Peer var peerNotValid bool var isStatusChanged bool @@ -776,6 +802,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } am.BufferUpdateAccountPeers(ctx, accountID) } @@ -831,6 +860,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + startTransaction := time.Now() err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { @@ -900,8 +930,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer return nil, nil, nil, err } + log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction)) + if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } + 
startBuffer := time.Now() am.BufferUpdateAccountPeers(ctx, accountID) + log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer)) } return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) @@ -997,11 +1034,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co } func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { - start := time.Now() - defer func() { - log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start)) - }() - if isRequiresApproval { network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -1014,9 +1046,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return peer, emptyMap, nil, nil } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } } approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) @@ -1024,10 +1064,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } + startPosture := time.Now() postureChecks, err := am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } + log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture)) customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings)) @@ -1037,7 +1079,13 @@ func 
(am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + var networkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1167,11 +1215,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun // Should be called when changes have to be synced to peers. func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) - return + var ( + account *types.Account + err error + ) + if am.experimentalNetworkMap(accountID) { + account = am.getAccountFromHolderOrInit(accountID) + } else { + account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. 
failed to get account: %v", err) + return + } } globalStart := time.Now() @@ -1204,6 +1259,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account resourcePolicies := account.GetResourcePoliciesMap() routers := account.GetResourceRoutersMap() + if am.experimentalNetworkMap(accountID) { + am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap) + } + proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) if err != nil { log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) @@ -1241,7 +1300,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start)) start = time.Now() - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountID) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start)) start = time.Now() @@ -1257,7 +1322,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) - am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update}) }(peer) } @@ 
-1351,7 +1416,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI return } - remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + var remotePeerNetworkMap *types.NetworkMap + + if am.experimentalNetworkMap(accountId) { + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + } else { + remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) + } proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] if ok { @@ -1368,7 +1439,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion) update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) + am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) } // getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found. 
@@ -1581,7 +1652,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto }, }, }, - NetworkMap: &types.NetworkMap{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) peerDeletedEvents = append(peerDeletedEvents, func() { diff --git a/management/server/peer_test.go b/management/server/peer_test.go index fd795b926..e151f5abb 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) { } func TestAccountManager_GetNetworkMap(t *testing.T) { + testGetNetworkMapGeneral(t) +} + +func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testGetNetworkMapGeneral(t) +} + +func testGetNetworkMapGeneral(t *testing.T) { manager, err := createManager(t) if err != nil { t.Fatal(err) @@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } } +func TestUpdateAccountPeers_Experimental(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") + testUpdateAccountPeers(t) +} + func TestUpdateAccountPeers(t *testing.T) { + testUpdateAccountPeers(t) +} + +func testUpdateAccountPeers(t *testing.T) { testCases := []struct { name string peers int @@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) { for _, channel := range peerChannels { update := <-channel assert.Nil(t, update.Update.NetbirdConfig) - assert.Equal(t, tc.peers, len(update.NetworkMap.Peers)) - assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules)) + assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers)) + assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules)) } }) } @@ -1161,7 +1179,7 @@ func TestToSyncResponse(t *testing.T) { } dnsCache := &DNSConfigCache{} accountSettings := &types.Settings{RoutingPeerDNSResolutionEnabled: true} - response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, 
nil, []string{}, dnsForwarderPort) + response := toSyncResponse(context.Background(), config, peer, turnRelayToken, turnRelayToken, networkMap, dnsName, checks, dnsCache, accountSettings, nil, []string{}, int64(dnsForwarderPort)) assert.NotNil(t, response) // assert peer config @@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { } func Test_LoginPeer(t *testing.T) { + t.Setenv(envNewNetworkMapBuilder, "true") if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") } diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index 891fa59bb..e6bdd2025 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -7,6 +7,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -22,6 +23,7 @@ type Manager interface { ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error GetPermissionsByRole(ctx context.Context, role types.UserRole) (roles.Permissions, error) + SetAccountManager(accountManager account.Manager) } type managerImpl struct { @@ -121,3 +123,7 @@ func (m *managerImpl) GetPermissionsByRole(ctx context.Context, role types.UserR return permissions, nil } + +func (m *managerImpl) SetAccountManager(accountManager account.Manager) { + // no-op +} diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index fa115d628..ec9f263f9 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -9,6 +9,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + account 
"github.com/netbirdio/netbird/management/server/account" modules "github.com/netbirdio/netbird/management/server/permissions/modules" operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" @@ -53,6 +54,18 @@ func (mr *MockManagerMockRecorder) GetPermissionsByRole(ctx, role interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPermissionsByRole", reflect.TypeOf((*MockManager)(nil).GetPermissionsByRole), ctx, role) } +// SetAccountManager mocks base method. +func (m *MockManager) SetAccountManager(accountManager account.Manager) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccountManager", accountManager) +} + +// SetAccountManager indicates an expected call of SetAccountManager. +func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager) +} + // ValidateAccountAccess mocks base method. 
func (m *MockManager) ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error { m.ctrl.T.Helper() diff --git a/management/server/policy.go b/management/server/policy.go index 9e4b3f73a..ff02d46aa 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/policy_test.go b/management/server/policy_test.go index 4a08f4c33..97ebbcf5a 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) { expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", Direction: types.FirewallRuleDirectionIN, Action: "accept", Protocol: "all", @@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) { PolicyID: "RuleDefault", }, { - PeerIP: "0.0.0.0", + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", 
+ PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.254.139", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.250.202", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionIN, + Action: "accept", + Protocol: "all", + Port: "", + PolicyID: "RuleDefault", + }, + { + PeerIP: "100.65.29.55", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "all", @@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) { // We expect a single permissive firewall rule which all outgoing connections peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) - assert.Len(t, firewallRules, 1) + assert.Len(t, firewallRules, 7) 
expectedFirewallRules := []*types.FirewallRule{ { - PeerIP: "0.0.0.0", + PeerIP: "100.65.80.39", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.14.88", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.62.5", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.32.206", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.13.186", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.29.55", + Direction: types.FirewallRuleDirectionOUT, + Action: "accept", + Protocol: "tcp", + Port: "80", + PolicyID: "RuleSwarm", + }, + { + PeerIP: "100.65.21.56", Direction: types.FirewallRuleDirectionOUT, Action: "accept", Protocol: "tcp", diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 943f2a970..f457b994b 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/route.go b/management/server/route.go index 4510426bb..05f7acf9e 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), 
accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return nil, err + } am.UpdateAccountPeers(ctx, accountID) } @@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } @@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { + if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil { + return err + } am.UpdateAccountPeers(ctx, accountID) } diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..f16b609f8 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -5,6 +5,9 @@ package settings import ( "context" "fmt" + "time" + + log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/integrations/extra_settings" @@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager { } func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) { + start := time.Now() + defer func() { + log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start)) + }() + if userID != activity.SystemInitiator { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read) if err != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 
382d026c8..94b7fc1cc 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -15,6 +16,8 @@ import ( "sync" "time" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "gorm.io/driver/mysql" "gorm.io/driver/postgres" @@ -46,6 +49,11 @@ const ( accountAndIDsQueryCondition = "account_id = ? AND id IN ?" accountIDCondition = "account_id = ?" peerNotFoundFMT = "peer %s not found" + + pgMaxConnections = 30 + pgMinConnections = 1 + pgMaxConnLifetime = 60 * time.Minute + pgHealthCheckPeriod = 1 * time.Minute ) // SqlStore represents an account storage backed by a Sql DB persisted to disk @@ -55,6 +63,7 @@ type SqlStore struct { metrics telemetry.AppMetrics installationPK int storeEngine types.Engine + pool *pgxpool.Pool } type installation struct { @@ -76,12 +85,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } - switch storeEngine { - case types.MysqlStoreEngine: - if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil { - return nil, err - } - case types.SqliteStoreEngine: + if storeEngine == types.SqliteStoreEngine { if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -89,8 +93,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met } sql.SetMaxOpenConns(conns) + sql.SetMaxIdleConns(conns) + sql.SetConnMaxLifetime(time.Hour) + sql.SetConnMaxIdleTime(3 * time.Minute) - log.WithContext(ctx).Infof("Set max open db connections to %d", conns) + log.WithContext(ctx).Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v", + conns, conns, time.Hour, 3*time.Minute) if skipMigration { log.WithContext(ctx).Infof("skipping migration") @@ -162,7 +170,7 @@ func (s 
*SqlStore) SaveAccount(ctx context.Context, account *types.Account) erro group.StoreGroupPeers() } - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -257,7 +265,7 @@ func (s *SqlStore) checkAccountDomainBeforeSave(ctx context.Context, accountID, func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) error { start := time.Now() - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Select(clause.Associations).Delete(account.Policies, "account_id = ?", account.Id) if result.Error != nil { return result.Error @@ -307,7 +315,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer. peerCopy := peer.Copy() peerCopy.AccountID = accountID - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { // check if peer exists before saving var peerID string result := tx.Model(&nbpeer.Peer{}).Select("id").Take(&peerID, accountAndIDQueryCondition, accountID, peer.ID) @@ -596,7 +604,7 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre } func (s *SqlStore) DeleteUser(ctx context.Context, accountID, userID string) error { - err := s.db.Transaction(func(tx *gorm.DB) error { + err := s.transaction(func(tx *gorm.DB) error { result := tx.Delete(&types.PersonalAccessToken{}, "user_id = ?", userID) if result.Error != nil { return result.Error @@ -774,6 +782,13 @@ func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types. 
} func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountPgx(ctx, accountID) + } + return s.getAccountGorm(ctx, accountID) +} + +func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { elapsed := time.Since(start) @@ -784,9 +799,19 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc var account types.Account result := s.db.Model(&account). - Omit("GroupsG"). - Preload("UsersG.PATsG"). // have to be specifies as this is nester reference - Preload(clause.Associations). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). 
Take(&account, idQueryCondition, accountID) if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) @@ -796,70 +821,1147 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return nil, status.NewGetAccountFromStoreError(result.Error) } - // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us - for i, policy := range account.Policies { - var rules []*types.PolicyRule - err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error - if err != nil { - return nil, status.Errorf(status.NotFound, "rule not found") - } - account.Policies[i].Rules = rules - } - account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) for _, key := range account.SetupKeysG { - account.SetupKeys[key.Key] = key.Copy() + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key } account.SetupKeysG = nil account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) for _, peer := range account.PeersG { - account.Peers[peer.ID] = peer.Copy() + account.Peers[peer.ID] = &peer } account.PeersG = nil - account.Users = make(map[string]*types.User, len(account.UsersG)) for _, user := range account.UsersG { user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) for _, pat := range user.PATsG { - user.PATs[pat.ID] = pat.Copy() + pat.UserID = "" + user.PATs[pat.ID] = &pat } - account.Users[user.Id] = user.Copy() + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil } account.UsersG = nil - account.Groups = make(map[string]*types.Group, len(account.GroupsG)) for _, group := range account.GroupsG { - account.Groups[group.ID] = group.Copy() + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range 
group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group } account.GroupsG = nil - var groupPeers []types.GroupPeer - s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). - Find(&groupPeers) - for _, groupPeer := range groupPeers { - if group, ok := account.Groups[groupPeer.GroupID]; ok { - group.Peers = append(group.Peers, groupPeer.PeerID) - } else { - log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + +func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + 
account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e } } - account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) - for _, route := range account.RoutesG { - account.Routes[route.ID] = route.Copy() + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + 
var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok 
:= patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route } - account.RoutesG = nil account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) - for _, ns := range account.NameServerGroupsG { - account.NameServerGroups[ns.ID] = ns.Copy() + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil account.NameServerGroupsG = nil + return account, nil +} + +func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { + var account types.Account + account.Network = &types.Network{} + const accountQuery = ` + SELECT + id, created_by, created_at, domain, domain_category, is_domain_primary_account, + -- Embedded Network + network_identifier, network_net, network_dns, network_serial, + -- Embedded DNSSettings + dns_settings_disabled_management_groups, + -- Embedded Settings + settings_peer_login_expiration_enabled, settings_peer_login_expiration, + settings_peer_inactivity_expiration_enabled, settings_peer_inactivity_expiration, + settings_regular_users_view_blocked, settings_groups_propagation_enabled, + 
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups, + settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range, + settings_lazy_connection_enabled, + -- Embedded ExtraSettings + settings_extra_peer_approval_enabled, settings_extra_user_approval_required, + settings_extra_integrated_validator, settings_extra_integrated_validator_groups + FROM accounts WHERE id = $1` + + var ( + sPeerLoginExpirationEnabled sql.NullBool + sPeerLoginExpiration sql.NullInt64 + sPeerInactivityExpirationEnabled sql.NullBool + sPeerInactivityExpiration sql.NullInt64 + sRegularUsersViewBlocked sql.NullBool + sGroupsPropagationEnabled sql.NullBool + sJWTGroupsEnabled sql.NullBool + sJWTGroupsClaimName sql.NullString + sJWTAllowGroups sql.NullString + sRoutingPeerDNSResolutionEnabled sql.NullBool + sDNSDomain sql.NullString + sNetworkRange sql.NullString + sLazyConnectionEnabled sql.NullBool + sExtraPeerApprovalEnabled sql.NullBool + sExtraUserApprovalRequired sql.NullBool + sExtraIntegratedValidator sql.NullString + sExtraIntegratedValidatorGroups sql.NullString + networkNet sql.NullString + dnsSettingsDisabledGroups sql.NullString + networkIdentifier sql.NullString + networkDns sql.NullString + networkSerial sql.NullInt64 + createdAt sql.NullTime + ) + err := s.pool.QueryRow(ctx, accountQuery, accountID).Scan( + &account.Id, &account.CreatedBy, &createdAt, &account.Domain, &account.DomainCategory, &account.IsDomainPrimaryAccount, + &networkIdentifier, &networkNet, &networkDns, &networkSerial, + &dnsSettingsDisabledGroups, + &sPeerLoginExpirationEnabled, &sPeerLoginExpiration, + &sPeerInactivityExpirationEnabled, &sPeerInactivityExpiration, + &sRegularUsersViewBlocked, &sGroupsPropagationEnabled, + &sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups, + &sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange, + &sLazyConnectionEnabled, + &sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired, + 
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(err) + } + + account.Settings = &types.Settings{Extra: &types.ExtraSettings{}} + if networkNet.Valid { + _ = json.Unmarshal([]byte(networkNet.String), &account.Network.Net) + } + if createdAt.Valid { + account.CreatedAt = createdAt.Time + } + if dnsSettingsDisabledGroups.Valid { + _ = json.Unmarshal([]byte(dnsSettingsDisabledGroups.String), &account.DNSSettings.DisabledManagementGroups) + } + if networkIdentifier.Valid { + account.Network.Identifier = networkIdentifier.String + } + if networkDns.Valid { + account.Network.Dns = networkDns.String + } + if networkSerial.Valid { + account.Network.Serial = uint64(networkSerial.Int64) + } + if sPeerLoginExpirationEnabled.Valid { + account.Settings.PeerLoginExpirationEnabled = sPeerLoginExpirationEnabled.Bool + } + if sPeerLoginExpiration.Valid { + account.Settings.PeerLoginExpiration = time.Duration(sPeerLoginExpiration.Int64) + } + if sPeerInactivityExpirationEnabled.Valid { + account.Settings.PeerInactivityExpirationEnabled = sPeerInactivityExpirationEnabled.Bool + } + if sPeerInactivityExpiration.Valid { + account.Settings.PeerInactivityExpiration = time.Duration(sPeerInactivityExpiration.Int64) + } + if sRegularUsersViewBlocked.Valid { + account.Settings.RegularUsersViewBlocked = sRegularUsersViewBlocked.Bool + } + if sGroupsPropagationEnabled.Valid { + account.Settings.GroupsPropagationEnabled = sGroupsPropagationEnabled.Bool + } + if sJWTGroupsEnabled.Valid { + account.Settings.JWTGroupsEnabled = sJWTGroupsEnabled.Bool + } + if sJWTGroupsClaimName.Valid { + account.Settings.JWTGroupsClaimName = sJWTGroupsClaimName.String + } + if sRoutingPeerDNSResolutionEnabled.Valid { + account.Settings.RoutingPeerDNSResolutionEnabled = sRoutingPeerDNSResolutionEnabled.Bool + } + if sDNSDomain.Valid { + 
account.Settings.DNSDomain = sDNSDomain.String + } + if sLazyConnectionEnabled.Valid { + account.Settings.LazyConnectionEnabled = sLazyConnectionEnabled.Bool + } + if sJWTAllowGroups.Valid { + _ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups) + } + if sNetworkRange.Valid { + _ = json.Unmarshal([]byte(sNetworkRange.String), &account.Settings.NetworkRange) + } + + if sExtraPeerApprovalEnabled.Valid { + account.Settings.Extra.PeerApprovalEnabled = sExtraPeerApprovalEnabled.Bool + } + if sExtraUserApprovalRequired.Valid { + account.Settings.Extra.UserApprovalRequired = sExtraUserApprovalRequired.Bool + } + if sExtraIntegratedValidator.Valid { + account.Settings.Extra.IntegratedValidator = sExtraIntegratedValidator.String + } + if sExtraIntegratedValidatorGroups.Valid { + _ = json.Unmarshal([]byte(sExtraIntegratedValidatorGroups.String), &account.Settings.Extra.IntegratedValidatorGroups) + } + account.InitOnce() return &account, nil } +func (s *SqlStore) getSetupKeys(ctx context.Context, accountID string) ([]types.SetupKey, error) { + const query = `SELECT id, account_id, key, key_secret, name, type, created_at, expires_at, updated_at, + revoked, used_times, last_used, auto_groups, usage_limit, ephemeral, allow_extra_dns_labels FROM setup_keys WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + keys, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.SetupKey, error) { + var sk types.SetupKey + var autoGroups []byte + var skCreatedAt, expiresAt, updatedAt, lastUsed sql.NullTime + var revoked, ephemeral, allowExtraDNSLabels sql.NullBool + var usedTimes, usageLimit sql.NullInt64 + + err := row.Scan(&sk.Id, &sk.AccountID, &sk.Key, &sk.KeySecret, &sk.Name, &sk.Type, &skCreatedAt, + &expiresAt, &updatedAt, &revoked, &usedTimes, &lastUsed, &autoGroups, &usageLimit, &ephemeral, &allowExtraDNSLabels) + + if err == nil { + if expiresAt.Valid { + sk.ExpiresAt = 
&expiresAt.Time + } + if skCreatedAt.Valid { + sk.CreatedAt = skCreatedAt.Time + } + if updatedAt.Valid { + sk.UpdatedAt = updatedAt.Time + if sk.UpdatedAt.IsZero() { + sk.UpdatedAt = sk.CreatedAt + } + } + if lastUsed.Valid { + sk.LastUsed = &lastUsed.Time + } + if revoked.Valid { + sk.Revoked = revoked.Bool + } + if usedTimes.Valid { + sk.UsedTimes = int(usedTimes.Int64) + } + if usageLimit.Valid { + sk.UsageLimit = int(usageLimit.Int64) + } + if ephemeral.Valid { + sk.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + sk.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &sk.AutoGroups) + } else { + sk.AutoGroups = []string{} + } + } + return sk, err + }) + if err != nil { + return nil, err + } + return keys, nil +} + +func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Peer, error) { + const query = `SELECT id, account_id, key, ip, name, dns_label, user_id, ssh_key, ssh_enabled, login_expiration_enabled, + inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, + meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, + meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, + meta_environment, meta_flags, meta_files, peer_status_last_seen, peer_status_connected, peer_status_login_expired, + peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, + location_geo_name_id FROM peers WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + + peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { + var p nbpeer.Peer + p.Status = &nbpeer.PeerStatus{} + var ( + lastLogin, createdAt sql.NullTime + sshEnabled, loginExpirationEnabled, 
inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + peerStatusLastSeen sql.NullTime + peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool + ip, extraDNS, netAddr, env, flags, files, connIP []byte + metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString + metaOS, metaOSVersion, metaWtVersion, metaUIVersion, metaKernelVersion sql.NullString + metaSystemSerialNumber, metaSystemProductName, metaSystemManufacturer sql.NullString + locationCountryCode, locationCityName sql.NullString + locationGeoNameID sql.NullInt64 + ) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, + &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &createdAt, &ephemeral, &extraDNS, + &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, + &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, + &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, + &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, + &locationCountryCode, &locationCityName, &locationGeoNameID) + + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + p.CreatedAt = createdAt.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if 
peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if metaHostname.Valid { + p.Meta.Hostname = metaHostname.String + } + if metaGoOS.Valid { + p.Meta.GoOS = metaGoOS.String + } + if metaKernel.Valid { + p.Meta.Kernel = metaKernel.String + } + if metaCore.Valid { + p.Meta.Core = metaCore.String + } + if metaPlatform.Valid { + p.Meta.Platform = metaPlatform.String + } + if metaOS.Valid { + p.Meta.OS = metaOS.String + } + if metaOSVersion.Valid { + p.Meta.OSVersion = metaOSVersion.String + } + if metaWtVersion.Valid { + p.Meta.WtVersion = metaWtVersion.String + } + if metaUIVersion.Valid { + p.Meta.UIVersion = metaUIVersion.String + } + if metaKernelVersion.Valid { + p.Meta.KernelVersion = metaKernelVersion.String + } + if metaSystemSerialNumber.Valid { + p.Meta.SystemSerialNumber = metaSystemSerialNumber.String + } + if metaSystemProductName.Valid { + p.Meta.SystemProductName = metaSystemProductName.String + } + if metaSystemManufacturer.Valid { + p.Meta.SystemManufacturer = metaSystemManufacturer.String + } + if locationCountryCode.Valid { + p.Location.CountryCode = locationCountryCode.String + } + if locationCityName.Valid { + p.Location.CityName = locationCityName.String + } + if locationGeoNameID.Valid { + p.Location.GeoNameID = uint(locationGeoNameID.Int64) + } + if ip != nil { + _ = json.Unmarshal(ip, &p.IP) + } + if extraDNS != nil { + _ = json.Unmarshal(extraDNS, &p.ExtraDNSLabels) + } + if netAddr != nil { + _ = json.Unmarshal(netAddr, &p.Meta.NetworkAddresses) + } + if env != nil { + _ = json.Unmarshal(env, &p.Meta.Environment) + } + if flags != nil { + _ = json.Unmarshal(flags, &p.Meta.Flags) + } + if files != nil { + _ = json.Unmarshal(files, &p.Meta.Files) + } + if connIP != nil { + _ = json.Unmarshal(connIP, &p.Location.ConnectionIP) + } + } + return p, err + }) + if err != nil { + return nil, err + } + return 
peers, nil +} + +func (s *SqlStore) getUsers(ctx context.Context, accountID string) ([]types.User, error) { + const query = `SELECT id, account_id, role, is_service_user, non_deletable, service_user_name, auto_groups, blocked, pending_approval, last_login, created_at, issued, integration_ref_id, integration_ref_integration_type FROM users WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + users, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.User, error) { + var u types.User + var autoGroups []byte + var lastLogin, createdAt sql.NullTime + var isServiceUser, nonDeletable, blocked, pendingApproval sql.NullBool + err := row.Scan(&u.Id, &u.AccountID, &u.Role, &isServiceUser, &nonDeletable, &u.ServiceUserName, &autoGroups, &blocked, &pendingApproval, &lastLogin, &createdAt, &u.Issued, &u.IntegrationReference.ID, &u.IntegrationReference.IntegrationType) + if err == nil { + if lastLogin.Valid { + u.LastLogin = &lastLogin.Time + } + if createdAt.Valid { + u.CreatedAt = createdAt.Time + } + if isServiceUser.Valid { + u.IsServiceUser = isServiceUser.Bool + } + if nonDeletable.Valid { + u.NonDeletable = nonDeletable.Bool + } + if blocked.Valid { + u.Blocked = blocked.Bool + } + if pendingApproval.Valid { + u.PendingApproval = pendingApproval.Bool + } + if autoGroups != nil { + _ = json.Unmarshal(autoGroups, &u.AutoGroups) + } else { + u.AutoGroups = []string{} + } + } + return u, err + }) + if err != nil { + return nil, err + } + return users, nil +} + +func (s *SqlStore) getGroups(ctx context.Context, accountID string) ([]*types.Group, error) { + const query = `SELECT id, account_id, name, issued, resources, integration_ref_id, integration_ref_integration_type FROM groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + groups, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Group, error) { + var g 
types.Group + var resources []byte + var refID sql.NullInt64 + var refType sql.NullString + err := row.Scan(&g.ID, &g.AccountID, &g.Name, &g.Issued, &resources, &refID, &refType) + if err == nil { + if refID.Valid { + g.IntegrationReference.ID = int(refID.Int64) + } + if refType.Valid { + g.IntegrationReference.IntegrationType = refType.String + } + if resources != nil { + _ = json.Unmarshal(resources, &g.Resources) + } else { + g.Resources = []types.Resource{} + } + g.GroupPeers = []types.GroupPeer{} + g.Peers = []string{} + } + return &g, err + }) + if err != nil { + return nil, err + } + return groups, nil +} + +func (s *SqlStore) getPolicies(ctx context.Context, accountID string) ([]*types.Policy, error) { + const query = `SELECT id, account_id, name, description, enabled, source_posture_checks FROM policies WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { + var p types.Policy + var checks []byte + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } + } + return &p, err + }) + if err != nil { + return nil, err + } + return policies, nil +} + +func (s *SqlStore) getRoutes(ctx context.Context, accountID string) ([]route.Route, error) { + const query = `SELECT id, account_id, network, domains, keep_route, net_id, description, peer, peer_groups, network_type, masquerade, metric, enabled, groups, access_control_groups, skip_auto_apply FROM routes WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { + var r route.Route + var network, domains, peerGroups, 
groups, accessGroups []byte + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) + if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if network != nil { + _ = json.Unmarshal(network, &r.Network) + } + if domains != nil { + _ = json.Unmarshal(domains, &r.Domains) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + if groups != nil { + _ = json.Unmarshal(groups, &r.Groups) + } + if accessGroups != nil { + _ = json.Unmarshal(accessGroups, &r.AccessControlGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + return routes, nil +} + +func (s *SqlStore) getNameServerGroups(ctx context.Context, accountID string) ([]nbdns.NameServerGroup, error) { + const query = `SELECT id, account_id, name, description, name_servers, groups, "primary", domains, enabled, search_domains_enabled FROM name_server_groups WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { + var n nbdns.NameServerGroup + var ns, groups, domains []byte + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) + if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + 
n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } + if ns != nil { + _ = json.Unmarshal(ns, &n.NameServers) + } else { + n.NameServers = []nbdns.NameServer{} + } + if groups != nil { + _ = json.Unmarshal(groups, &n.Groups) + } else { + n.Groups = []string{} + } + if domains != nil { + _ = json.Unmarshal(domains, &n.Domains) + } else { + n.Domains = []string{} + } + } + return n, err + }) + if err != nil { + return nil, err + } + return nsgs, nil +} + +func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + const query = `SELECT id, account_id, name, description, checks FROM posture_checks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + checks, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*posture.Checks, error) { + var c posture.Checks + var checksDef []byte + err := row.Scan(&c.ID, &c.AccountID, &c.Name, &c.Description, &checksDef) + if err == nil && checksDef != nil { + _ = json.Unmarshal(checksDef, &c.Checks) + } + return &c, err + }) + if err != nil { + return nil, err + } + return checks, nil +} + +func (s *SqlStore) getNetworks(ctx context.Context, accountID string) ([]*networkTypes.Network, error) { + const query = `SELECT id, account_id, name, description FROM networks WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + networks, err := pgx.CollectRows(rows, pgx.RowToStructByName[networkTypes.Network]) + if err != nil { + return nil, err + } + result := make([]*networkTypes.Network, len(networks)) + for i := range networks { + result[i] = &networks[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkRouters(ctx context.Context, accountID string) ([]*routerTypes.NetworkRouter, error) { + const query = `SELECT id, network_id, account_id, peer, peer_groups, masquerade, metric, enabled FROM network_routers WHERE account_id = $1` + rows, err := 
s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { + var r routerTypes.NetworkRouter + var peerGroups []byte + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + result[i] = &routers[i] + } + return result, nil +} + +func (s *SqlStore) getNetworkResources(ctx context.Context, accountID string) ([]*resourceTypes.NetworkResource, error) { + const query = `SELECT id, network_id, account_id, name, description, type, domain, prefix, enabled FROM network_resources WHERE account_id = $1` + rows, err := s.pool.Query(ctx, query, accountID) + if err != nil { + return nil, err + } + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { + var r resourceTypes.NetworkResource + var prefix []byte + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } + } + return r, err + }) + if err != nil { + return nil, err + } + result := make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + result[i] = &resources[i] + } + return result, nil +} + +func (s *SqlStore) getAccountOnboarding(ctx context.Context, accountID string, 
account *types.Account) error { + const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool + var createdAt, updatedAt sql.NullTime + err := s.pool.QueryRow(ctx, query, accountID).Scan( + &account.Onboarding.AccountID, + &onboardingFlowPending, + &signupFormPending, + &createdAt, + &updatedAt, + ) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return err + } + if createdAt.Valid { + account.Onboarding.CreatedAt = createdAt.Time + } + if updatedAt.Valid { + account.Onboarding.UpdatedAt = updatedAt.Time + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool + } + return nil +} + +func (s *SqlStore) getPersonalAccessTokens(ctx context.Context, userIDs []string) ([]types.PersonalAccessToken, error) { + if len(userIDs) == 0 { + return nil, nil + } + const query = `SELECT id, user_id, name, hashed_token, expiration_date, created_by, created_at, last_used FROM personal_access_tokens WHERE user_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, userIDs) + if err != nil { + return nil, err + } + pats, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed, createdAt sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &createdAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if createdAt.Valid { + pat.CreatedAt = createdAt.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) + if err != nil { + return nil, err + } + return pats, nil +} + +func (s *SqlStore) getPolicyRules(ctx context.Context, 
policyIDs []string) ([]*types.PolicyRule, error) { + if len(policyIDs) == 0 { + return nil, nil + } + const query = `SELECT id, policy_id, name, description, enabled, action, destinations, destination_resource, sources, source_resource, bidirectional, protocol, ports, port_ranges FROM policy_rules WHERE policy_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, policyIDs) + if err != nil { + return nil, err + } + rules, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { + var r types.PolicyRule + var dest, destRes, sources, sourceRes, ports, portRanges []byte + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } + if dest != nil { + _ = json.Unmarshal(dest, &r.Destinations) + } + if destRes != nil { + _ = json.Unmarshal(destRes, &r.DestinationResource) + } + if sources != nil { + _ = json.Unmarshal(sources, &r.Sources) + } + if sourceRes != nil { + _ = json.Unmarshal(sourceRes, &r.SourceResource) + } + if ports != nil { + _ = json.Unmarshal(ports, &r.Ports) + } + if portRanges != nil { + _ = json.Unmarshal(portRanges, &r.PortRanges) + } + } + return &r, err + }) + if err != nil { + return nil, err + } + return rules, nil +} + +func (s *SqlStore) getGroupPeers(ctx context.Context, groupIDs []string) ([]types.GroupPeer, error) { + if len(groupIDs) == 0 { + return nil, nil + } + const query = `SELECT account_id, group_id, peer_id FROM group_peers WHERE group_id = ANY($1)` + rows, err := s.pool.Query(ctx, query, groupIDs) + if err != nil { + return nil, err + } + groupPeers, err := pgx.CollectRows(rows, pgx.RowToStructByName[types.GroupPeer]) + if err != nil { + return nil, err + } + return groupPeers, nil +} + func (s *SqlStore) 
GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) { var user types.User result := s.db.Select("account_id").Take(&user, idQueryCondition, userID) @@ -1199,8 +2301,41 @@ func NewPostgresqlStore(ctx context.Context, dsn string, metrics telemetry.AppMe if err != nil { return nil, err } + pool, err := connectToPgDb(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} - return NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) +func connectToPgDb(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = pgMaxConnections + config.MinConns = pgMinConnections + config.MaxConnLifetime = pgMaxConnLifetime + config.HealthCheckPeriod = pgHealthCheckPeriod + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil } // NewMysqlStore creates a new MySQL store. @@ -1269,7 +2404,7 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data // NewPostgresqlStoreFromSqlStore restores a store from SqlStore and stores Postgres DB. 
func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { - store, err := NewPostgresqlStore(ctx, dsn, metrics, false) + store, err := NewPostgresqlStoreForTests(ctx, dsn, metrics, false) if err != nil { return nil, err } @@ -1289,6 +2424,50 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, return store, nil } +// used for tests only +func NewPostgresqlStoreForTests(ctx context.Context, dsn string, metrics telemetry.AppMetrics, skipMigration bool) (*SqlStore, error) { + db, err := gorm.Open(postgres.Open(dsn), getGormConfig()) + if err != nil { + return nil, err + } + pool, err := connectToPgDbForTests(context.Background(), dsn) + if err != nil { + return nil, err + } + store, err := NewSqlStore(ctx, db, types.PostgresStoreEngine, metrics, skipMigration) + if err != nil { + pool.Close() + return nil, err + } + store.pool = pool + return store, nil +} + +// used for tests only +func connectToPgDbForTests(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 5 + config.MinConns = 1 + config.MaxConnLifetime = 30 * time.Second + config.HealthCheckPeriod = 10 * time.Second + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + + return pool, nil +} + // NewMysqlStoreFromSqlStore restores a store from SqlStore and stores MySQL DB. 
func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn string, metrics telemetry.AppMetrics) (*SqlStore, error) { store, err := NewMysqlStore(ctx, dsn, metrics, false) @@ -1735,6 +2914,33 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor if tx.Error != nil { return tx.Error } + defer func() { + if r := recover(); r != nil { + tx.Rollback() + panic(r) + } + }() + + if s.storeEngine == types.PostgresStoreEngine { + if err := tx.Exec("SET LOCAL statement_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set statement timeout: %w", err) + } + if err := tx.Exec("SET LOCAL lock_timeout = '1min'").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to set lock timeout: %w", err) + } + } + + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + repo := s.withTx(tx) err := operation(repo) if err != nil { @@ -1742,6 +2948,14 @@ func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(stor return err } + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; err != nil { + tx.Rollback() + return fmt.Errorf("failed to re-enable FK checks: %w", err) + } + } + err = tx.Commit().Error log.WithContext(ctx).Tracef("transaction took %v", time.Since(startTime)) @@ -1759,6 +2973,31 @@ func (s *SqlStore) withTx(tx *gorm.DB) Store { } } +// transaction wraps a GORM transaction with MySQL-specific FK checks handling +// Use this instead of db.Transaction() directly to avoid deadlocks on MySQL/Aurora +func (s *SqlStore) transaction(fn func(*gorm.DB) error) 
error { + return s.db.Transaction(func(tx *gorm.DB) error { + // For MySQL, disable FK checks within this transaction to avoid deadlocks + // This is session-scoped and doesn't require SUPER privileges + if s.storeEngine == types.MysqlStoreEngine { + if err := tx.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return fmt.Errorf("failed to disable FK checks: %w", err) + } + } + + err := fn(tx) + + // Re-enable FK checks before commit (optional, as transaction end resets it) + if s.storeEngine == types.MysqlStoreEngine && err == nil { + if fkErr := tx.Exec("SET FOREIGN_KEY_CHECKS = 1").Error; fkErr != nil { + return fmt.Errorf("failed to re-enable FK checks: %w", fkErr) + } + } + + return err + }) +} + func (s *SqlStore) GetDB() *gorm.DB { return s.db } @@ -2015,7 +3254,7 @@ func (s *SqlStore) SavePolicy(ctx context.Context, policy *types.Policy) error { } func (s *SqlStore) DeletePolicy(ctx context.Context, accountID, policyID string) error { - return s.db.Transaction(func(tx *gorm.DB) error { + return s.transaction(func(tx *gorm.DB) error { if err := tx.Where("policy_id = ?", policyID).Delete(&types.PolicyRule{}).Error; err != nil { return fmt.Errorf("delete policy rules: %w", err) } diff --git a/management/server/store/sql_store_get_account_test.go b/management/server/store/sql_store_get_account_test.go new file mode 100644 index 000000000..8ff04d68a --- /dev/null +++ b/management/server/store/sql_store_get_account_test.go @@ -0,0 +1,1089 @@ +package store + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + nbdns "github.com/netbirdio/netbird/dns" + "github.com/netbirdio/netbird/management/server/integration_reference" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes 
"github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads +// all fields and nested objects from the database, including deeply nested structures. +func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) { + if testing.Short() { + t.Skip("skipping comprehensive test in short mode") + } + + ctx := context.Background() + store, cleanup, err := NewTestStoreFromSQL(ctx, "", t.TempDir()) + require.NoError(t, err) + defer cleanup() + + // Create comprehensive test data + accountID := "test-account-comprehensive" + userID1 := "user-1" + userID2 := "user-2" + peerID1 := "peer-1" + peerID2 := "peer-2" + peerID3 := "peer-3" + groupID1 := "group-1" + groupID2 := "group-2" + setupKeyID1 := "setup-key-1" + setupKeyID2 := "setup-key-2" + routeID1 := route.ID("route-1") + routeID2 := route.ID("route-2") + nsGroupID1 := "ns-group-1" + nsGroupID2 := "ns-group-2" + policyID1 := "policy-1" + policyID2 := "policy-2" + postureCheckID1 := "posture-check-1" + postureCheckID2 := "posture-check-2" + networkID1 := "network-1" + routerID1 := "router-1" + resourceID1 := "resource-1" + patID1 := "pat-1" + patID2 := "pat-2" + patID3 := "pat-3" + + now := time.Now().UTC().Truncate(time.Second) + lastLogin := now.Add(-24 * time.Hour) + patLastUsed := now.Add(-1 * time.Hour) + + // Build comprehensive account with all fields populated + account := &types.Account{ + Id: accountID, + CreatedBy: userID1, + CreatedAt: now, + Domain: "example.com", + DomainCategory: "business", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "test-network", + Net: net.IPNet{ + IP: net.ParseIP("100.64.0.0"), + Mask: net.CIDRMask(10, 32), + }, + Dns: "test-dns", + Serial: 42, + 
}, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"dns-group-1", "dns-group-2"}, + }, + Settings: &types.Settings{ + PeerLoginExpirationEnabled: true, + PeerLoginExpiration: time.Hour * 24 * 30, + GroupsPropagationEnabled: true, + JWTGroupsEnabled: true, + JWTGroupsClaimName: "groups", + JWTAllowGroups: []string{"allowed-group-1", "allowed-group-2"}, + RegularUsersViewBlocked: false, + Extra: &types.ExtraSettings{ + PeerApprovalEnabled: true, + IntegratedValidatorGroups: []string{"validator-1"}, + }, + }, + } + + // Create Setup Keys with all fields + setupKey1ExpiresAt := now.Add(30 * 24 * time.Hour) + setupKey1LastUsed := now.Add(-2 * time.Hour) + setupKey1 := &types.SetupKey{ + Id: setupKeyID1, + AccountID: accountID, + Key: "setup-key-secret-1", + Name: "Setup Key 1", + Type: types.SetupKeyReusable, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey1ExpiresAt, + Revoked: false, + UsedTimes: 5, + LastUsed: &setupKey1LastUsed, + AutoGroups: []string{groupID1, groupID2}, + UsageLimit: 100, + Ephemeral: false, + } + + setupKey2ExpiresAt := now.Add(7 * 24 * time.Hour) + setupKey2LastUsed := now.Add(-1 * time.Hour) + setupKey2 := &types.SetupKey{ + Id: setupKeyID2, + AccountID: accountID, + Key: "setup-key-secret-2", + Name: "Setup Key 2 (One-off)", + Type: types.SetupKeyOneOff, + CreatedAt: now, + UpdatedAt: now, + ExpiresAt: &setupKey2ExpiresAt, + Revoked: true, + UsedTimes: 1, + LastUsed: &setupKey2LastUsed, + AutoGroups: []string{}, + UsageLimit: 1, + Ephemeral: true, + } + + account.SetupKeys = map[string]*types.SetupKey{ + setupKey1.Key: setupKey1, + setupKey2.Key: setupKey2, + } + + // Create Peers with comprehensive fields + peer1 := &nbpeer.Peer{ + ID: peerID1, + AccountID: accountID, + Key: "peer-key-1-AAAA", + Name: "Peer 1", + IP: net.ParseIP("100.64.0.1"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer1.example.com", + GoOS: "linux", + Kernel: "5.15.0", + Core: "x86_64", + Platform: "ubuntu", + OS: "Ubuntu 22.04", + 
WtVersion: "0.24.0", + UIVersion: "0.24.0", + KernelVersion: "5.15.0-78-generic", + OSVersion: "22.04", + NetworkAddresses: []nbpeer.NetworkAddress{ + {NetIP: netip.MustParsePrefix("192.168.1.10/32"), Mac: "00:11:22:33:44:55"}, + {NetIP: netip.MustParsePrefix("10.0.0.5/32"), Mac: "00:11:22:33:44:66"}, + }, + SystemSerialNumber: "ABC123", + SystemProductName: "Server Model X", + SystemManufacturer: "Dell Inc.", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-5 * time.Minute), + Connected: true, + LoginExpired: false, + RequiresApproval: false, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("203.0.113.10"), + CountryCode: "US", + CityName: "San Francisco", + GeoNameID: 5391959, + }, + SSHEnabled: true, + SSHKey: "ssh-rsa AAAAB3NzaC1...", + UserID: userID1, + LoginExpirationEnabled: true, + InactivityExpirationEnabled: false, + DNSLabel: "peer1", + CreatedAt: now.Add(-30 * 24 * time.Hour), + Ephemeral: false, + } + + peer2 := &nbpeer.Peer{ + ID: peerID2, + AccountID: accountID, + Key: "peer-key-2-BBBB", + Name: "Peer 2", + IP: net.ParseIP("100.64.0.2"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: "peer2.example.com", + GoOS: "darwin", + Kernel: "22.0.0", + Core: "arm64", + Platform: "darwin", + OS: "macOS Ventura", + WtVersion: "0.24.0", + UIVersion: "0.24.0", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-1 * time.Hour), + Connected: false, + LoginExpired: true, + RequiresApproval: true, + }, + Location: nbpeer.Location{ + ConnectionIP: net.ParseIP("198.51.100.20"), + CountryCode: "GB", + CityName: "London", + GeoNameID: 2643743, + }, + SSHEnabled: false, + UserID: userID2, + LoginExpirationEnabled: false, + InactivityExpirationEnabled: true, + DNSLabel: "peer2", + CreatedAt: now.Add(-15 * 24 * time.Hour), + Ephemeral: false, + } + + peer3 := &nbpeer.Peer{ + ID: peerID3, + AccountID: accountID, + Key: "peer-key-3-CCCC", + Name: "Peer 3 (Ephemeral)", + IP: net.ParseIP("100.64.0.3"), + Meta: nbpeer.PeerSystemMeta{ + Hostname: 
"peer3.example.com", + GoOS: "windows", + Platform: "windows", + }, + Status: &nbpeer.PeerStatus{ + LastSeen: now.Add(-10 * time.Minute), + Connected: true, + }, + DNSLabel: "peer3", + CreatedAt: now.Add(-1 * time.Hour), + Ephemeral: true, + } + + account.Peers = map[string]*nbpeer.Peer{ + peerID1: peer1, + peerID2: peer2, + peerID3: peer3, + } + + // Create Users with PATs + pat1ExpirationDate := now.Add(90 * 24 * time.Hour) + pat1 := &types.PersonalAccessToken{ + ID: patID1, + Name: "PAT 1", + HashedToken: "hashed-token-1", + ExpirationDate: &pat1ExpirationDate, + CreatedAt: now.Add(-10 * 24 * time.Hour), + CreatedBy: userID1, + LastUsed: &patLastUsed, + } + + pat2ExpirationDate := now.Add(30 * 24 * time.Hour) + pat2 := &types.PersonalAccessToken{ + ID: patID2, + Name: "PAT 2", + HashedToken: "hashed-token-2", + ExpirationDate: &pat2ExpirationDate, + CreatedAt: now.Add(-5 * 24 * time.Hour), + CreatedBy: userID1, + } + + pat3ExpirationDate := now.Add(60 * 24 * time.Hour) + pat3 := &types.PersonalAccessToken{ + ID: patID3, + Name: "PAT 3", + HashedToken: "hashed-token-3", + ExpirationDate: &pat3ExpirationDate, + CreatedAt: now.Add(-2 * 24 * time.Hour), + CreatedBy: userID2, + } + + user1 := &types.User{ + Id: userID1, + AccountID: accountID, + Role: types.UserRoleOwner, + IsServiceUser: false, + NonDeletable: true, + AutoGroups: []string{groupID1}, + Issued: types.UserIssuedAPI, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 123, + IntegrationType: "azure_ad", + }, + CreatedAt: now.Add(-60 * 24 * time.Hour), + LastLogin: &lastLogin, + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID1: pat1, + patID2: pat2, + }, + } + + user2 := &types.User{ + Id: userID2, + AccountID: accountID, + Role: types.UserRoleAdmin, + IsServiceUser: true, + NonDeletable: false, + AutoGroups: []string{groupID2}, + Issued: types.UserIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 456, + 
IntegrationType: "google_workspace", + }, + CreatedAt: now.Add(-30 * 24 * time.Hour), + Blocked: false, + PATs: map[string]*types.PersonalAccessToken{ + patID3: pat3, + }, + } + + account.Users = map[string]*types.User{ + userID1: user1, + userID2: user2, + } + + // Create Groups with peers and resources + group1 := &types.Group{ + ID: groupID1, + AccountID: accountID, + Name: "Group 1", + Issued: types.GroupIssuedAPI, + Peers: []string{peerID1, peerID2}, + Resources: []types.Resource{ + { + ID: "resource-1", + Type: types.ResourceTypeHost, + }, + }, + } + + group2 := &types.Group{ + ID: groupID2, + AccountID: accountID, + Name: "Group 2", + Issued: types.GroupIssuedIntegration, + IntegrationReference: integration_reference.IntegrationReference{ + ID: 789, + IntegrationType: "okta", + }, + Peers: []string{peerID3}, + Resources: []types.Resource{}, + } + + account.Groups = map[string]*types.Group{ + groupID1: group1, + groupID2: group2, + } + + // Create Policies with Rules + policy1 := &types.Policy{ + ID: policyID1, + AccountID: accountID, + Name: "Policy 1", + Description: "Main access policy", + Enabled: true, + Rules: []*types.PolicyRule{ + { + ID: "rule-1", + PolicyID: policyID1, + Name: "Rule 1", + Description: "Allow access", + Enabled: true, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolALL, + Ports: []string{}, + PortRanges: []types.RulePortRange{}, + Sources: []string{groupID1}, + Destinations: []string{groupID2}, + }, + { + ID: "rule-2", + PolicyID: policyID1, + Name: "Rule 2", + Description: "Block traffic on specific ports", + Enabled: true, + Action: types.PolicyTrafficActionDrop, + Bidirectional: false, + Protocol: types.PolicyRuleProtocolTCP, + Ports: []string{"22", "3389"}, + PortRanges: []types.RulePortRange{ + {Start: 8000, End: 8999}, + }, + Sources: []string{groupID2}, + Destinations: []string{groupID1}, + }, + }, + } + + policy2 := &types.Policy{ + ID: policyID2, + AccountID: accountID, + 
Name: "Policy 2", + Description: "Secondary policy", + Enabled: false, + Rules: []*types.PolicyRule{ + { + ID: "rule-3", + PolicyID: policyID2, + Name: "Rule 3", + Description: "UDP access", + Enabled: false, + Action: types.PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: types.PolicyRuleProtocolUDP, + Ports: []string{"53"}, + Sources: []string{groupID1}, + Destinations: []string{groupID1}, + }, + }, + } + + account.Policies = []*types.Policy{policy1, policy2} + + // Create Routes + route1 := &route.Route{ + ID: routeID1, + AccountID: accountID, + Network: netip.MustParsePrefix("10.0.0.0/24"), + NetworkType: route.IPv4Network, + Peer: peerID1, + PeerGroups: []string{}, + Description: "Route 1", + NetID: "net-id-1", + Masquerade: true, + Metric: 9999, + Enabled: true, + Groups: []string{groupID1}, + AccessControlGroups: []string{groupID2}, + } + + route2 := &route.Route{ + ID: routeID2, + AccountID: accountID, + Network: netip.MustParsePrefix("192.168.1.0/24"), + NetworkType: route.IPv4Network, + Peer: "", + PeerGroups: []string{groupID2}, + Description: "Route 2 (High Availability)", + NetID: "net-id-2", + Masquerade: false, + Metric: 100, + Enabled: true, + Groups: []string{groupID1, groupID2}, + AccessControlGroups: []string{groupID1}, + } + + account.Routes = map[route.ID]*route.Route{ + routeID1: route1, + routeID2: route2, + } + + // Create NameServer Groups + nsGroup1 := &nbdns.NameServerGroup{ + ID: nsGroupID1, + AccountID: accountID, + Name: "NS Group 1", + Description: "Primary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("8.8.8.8"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + { + IP: netip.MustParseAddr("8.8.4.4"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{groupID1, groupID2}, + Domains: []string{"example.com", "test.com"}, + Enabled: true, + Primary: true, + SearchDomainsEnabled: true, + } + + nsGroup2 := &nbdns.NameServerGroup{ + ID: nsGroupID2, + AccountID: 
accountID, + Name: "NS Group 2", + Description: "Secondary nameservers", + NameServers: []nbdns.NameServer{ + { + IP: netip.MustParseAddr("1.1.1.1"), + NSType: nbdns.UDPNameServerType, + Port: 53, + }, + }, + Groups: []string{}, + Domains: []string{}, + Enabled: false, + Primary: false, + SearchDomainsEnabled: false, + } + + account.NameServerGroups = map[string]*nbdns.NameServerGroup{ + nsGroupID1: nsGroup1, + nsGroupID2: nsGroup2, + } + + // Create Posture Checks + postureCheck1 := &posture.Checks{ + ID: postureCheckID1, + AccountID: accountID, + Name: "Posture Check 1", + Description: "OS version check", + Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{ + MinVersion: "0.24.0", + }, + OSVersionCheck: &posture.OSVersionCheck{ + Ios: &posture.MinVersionCheck{ + MinVersion: "16.0", + }, + Darwin: &posture.MinVersionCheck{ + MinVersion: "22.0.0", + }, + }, + }, + } + + postureCheck2 := &posture.Checks{ + ID: postureCheckID2, + AccountID: accountID, + Name: "Posture Check 2", + Description: "Geo location check", + Checks: posture.ChecksDefinition{ + GeoLocationCheck: &posture.GeoLocationCheck{ + Locations: []posture.Location{ + { + CountryCode: "US", + CityName: "San Francisco", + }, + { + CountryCode: "GB", + CityName: "London", + }, + }, + Action: "allow", + }, + PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{ + Ranges: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/16"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + Action: "allow", + }, + }, + } + + account.PostureChecks = []*posture.Checks{postureCheck1, postureCheck2} + + // Create Networks + network1 := &networkTypes.Network{ + ID: networkID1, + AccountID: accountID, + Name: "Network 1", + Description: "Primary network", + } + + account.Networks = []*networkTypes.Network{network1} + + // Create Network Routers + router1 := &routerTypes.NetworkRouter{ + ID: routerID1, + AccountID: accountID, + NetworkID: networkID1, + Peer: peerID1, + PeerGroups: []string{}, + 
Masquerade: true, + Metric: 100, + } + + account.NetworkRouters = []*routerTypes.NetworkRouter{router1} + + // Create Network Resources + resource1 := &resourceTypes.NetworkResource{ + ID: resourceID1, + AccountID: accountID, + NetworkID: networkID1, + Name: "Resource 1", + Description: "Web server", + Prefix: netip.MustParsePrefix("192.168.1.100/32"), + Type: resourceTypes.Host, + } + + account.NetworkResources = []*resourceTypes.NetworkResource{resource1} + + // Create Onboarding + account.Onboarding = types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + SignupFormPending: false, + CreatedAt: now, + UpdatedAt: now, + } + + // Save the account to the database + err = store.SaveAccount(ctx, account) + require.NoError(t, err, "Failed to save comprehensive test account") + + // Retrieve the account from the database + retrievedAccount, err := store.GetAccount(ctx, accountID) + require.NoError(t, err, "Failed to retrieve account") + require.NotNil(t, retrievedAccount, "Retrieved account should not be nil") + + // ========== VALIDATE TOP-LEVEL FIELDS ========== + t.Run("TopLevelFields", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Id, "Account ID mismatch") + assert.Equal(t, userID1, retrievedAccount.CreatedBy, "CreatedBy mismatch") + assert.WithinDuration(t, now, retrievedAccount.CreatedAt, time.Second, "CreatedAt mismatch") + assert.Equal(t, "example.com", retrievedAccount.Domain, "Domain mismatch") + assert.Equal(t, "business", retrievedAccount.DomainCategory, "DomainCategory mismatch") + assert.True(t, retrievedAccount.IsDomainPrimaryAccount, "IsDomainPrimaryAccount should be true") + }) + + // ========== VALIDATE EMBEDDED NETWORK ========== + t.Run("EmbeddedNetwork", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Network, "Network should not be nil") + assert.Equal(t, "test-network", retrievedAccount.Network.Identifier, "Network Identifier mismatch") + assert.Equal(t, "test-dns", 
retrievedAccount.Network.Dns, "Network DNS mismatch") + assert.Equal(t, uint64(42), retrievedAccount.Network.Serial, "Network Serial mismatch") + + expectedIP := net.ParseIP("100.64.0.0") + assert.True(t, retrievedAccount.Network.Net.IP.Equal(expectedIP), "Network IP mismatch") + expectedMask := net.CIDRMask(10, 32) + assert.Equal(t, expectedMask, retrievedAccount.Network.Net.Mask, "Network Mask mismatch") + }) + + // ========== VALIDATE DNS SETTINGS ========== + t.Run("DNSSettings", func(t *testing.T) { + assert.Len(t, retrievedAccount.DNSSettings.DisabledManagementGroups, 2, "DisabledManagementGroups length mismatch") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-1", "Missing dns-group-1") + assert.Contains(t, retrievedAccount.DNSSettings.DisabledManagementGroups, "dns-group-2", "Missing dns-group-2") + }) + + // ========== VALIDATE SETTINGS ========== + t.Run("Settings", func(t *testing.T) { + require.NotNil(t, retrievedAccount.Settings, "Settings should not be nil") + assert.True(t, retrievedAccount.Settings.PeerLoginExpirationEnabled, "PeerLoginExpirationEnabled mismatch") + assert.Equal(t, time.Hour*24*30, retrievedAccount.Settings.PeerLoginExpiration, "PeerLoginExpiration mismatch") + assert.True(t, retrievedAccount.Settings.GroupsPropagationEnabled, "GroupsPropagationEnabled mismatch") + assert.True(t, retrievedAccount.Settings.JWTGroupsEnabled, "JWTGroupsEnabled mismatch") + assert.Equal(t, "groups", retrievedAccount.Settings.JWTGroupsClaimName, "JWTGroupsClaimName mismatch") + assert.Len(t, retrievedAccount.Settings.JWTAllowGroups, 2, "JWTAllowGroups length mismatch") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-1") + assert.Contains(t, retrievedAccount.Settings.JWTAllowGroups, "allowed-group-2") + assert.False(t, retrievedAccount.Settings.RegularUsersViewBlocked, "RegularUsersViewBlocked mismatch") + + // Validate Extra Settings + require.NotNil(t, retrievedAccount.Settings.Extra, 
"Extra settings should not be nil") + assert.True(t, retrievedAccount.Settings.Extra.PeerApprovalEnabled, "PeerApprovalEnabled mismatch") + assert.Len(t, retrievedAccount.Settings.Extra.IntegratedValidatorGroups, 1, "IntegratedValidatorGroups length mismatch") + assert.Equal(t, "validator-1", retrievedAccount.Settings.Extra.IntegratedValidatorGroups[0]) + }) + + // ========== VALIDATE SETUP KEYS ========== + t.Run("SetupKeys", func(t *testing.T) { + require.Len(t, retrievedAccount.SetupKeys, 2, "Should have 2 setup keys") + + // Validate Setup Key 1 + sk1, exists := retrievedAccount.SetupKeys["setup-key-secret-1"] + require.True(t, exists, "Setup key 1 should exist") + assert.Equal(t, "Setup Key 1", sk1.Name, "Setup key 1 name mismatch") + assert.Equal(t, types.SetupKeyReusable, sk1.Type, "Setup key 1 type mismatch") + assert.False(t, sk1.Revoked, "Setup key 1 should not be revoked") + assert.Equal(t, 5, sk1.UsedTimes, "Setup key 1 used times mismatch") + assert.Equal(t, 100, sk1.UsageLimit, "Setup key 1 usage limit mismatch") + assert.False(t, sk1.Ephemeral, "Setup key 1 should not be ephemeral") + assert.Len(t, sk1.AutoGroups, 2, "Setup key 1 auto groups length mismatch") + assert.Contains(t, sk1.AutoGroups, groupID1) + assert.Contains(t, sk1.AutoGroups, groupID2) + + // Validate Setup Key 2 + sk2, exists := retrievedAccount.SetupKeys["setup-key-secret-2"] + require.True(t, exists, "Setup key 2 should exist") + assert.Equal(t, "Setup Key 2 (One-off)", sk2.Name, "Setup key 2 name mismatch") + assert.Equal(t, types.SetupKeyOneOff, sk2.Type, "Setup key 2 type mismatch") + assert.True(t, sk2.Revoked, "Setup key 2 should be revoked") + assert.Equal(t, 1, sk2.UsedTimes, "Setup key 2 used times mismatch") + assert.Equal(t, 1, sk2.UsageLimit, "Setup key 2 usage limit mismatch") + assert.True(t, sk2.Ephemeral, "Setup key 2 should be ephemeral") + assert.Len(t, sk2.AutoGroups, 0, "Setup key 2 should have empty auto groups") + }) + + // ========== VALIDATE PEERS ========== 
+ t.Run("Peers", func(t *testing.T) { + require.Len(t, retrievedAccount.Peers, 3, "Should have 3 peers") + + // Validate Peer 1 + p1, exists := retrievedAccount.Peers[peerID1] + require.True(t, exists, "Peer 1 should exist") + assert.Equal(t, "Peer 1", p1.Name, "Peer 1 name mismatch") + assert.Equal(t, "peer-key-1-AAAA", p1.Key, "Peer 1 key mismatch") + assert.True(t, p1.IP.Equal(net.ParseIP("100.64.0.1")), "Peer 1 IP mismatch") + assert.Equal(t, userID1, p1.UserID, "Peer 1 user ID mismatch") + assert.True(t, p1.SSHEnabled, "Peer 1 SSH should be enabled") + assert.Equal(t, "ssh-rsa AAAAB3NzaC1...", p1.SSHKey, "Peer 1 SSH key mismatch") + assert.True(t, p1.LoginExpirationEnabled, "Peer 1 login expiration should be enabled") + assert.False(t, p1.Ephemeral, "Peer 1 should not be ephemeral") + assert.Equal(t, "peer1", p1.DNSLabel, "Peer 1 DNS label mismatch") + + // Validate Peer 1 Meta + assert.Equal(t, "peer1.example.com", p1.Meta.Hostname, "Peer 1 hostname mismatch") + assert.Equal(t, "linux", p1.Meta.GoOS, "Peer 1 OS mismatch") + assert.Equal(t, "5.15.0", p1.Meta.Kernel, "Peer 1 kernel mismatch") + assert.Equal(t, "x86_64", p1.Meta.Core, "Peer 1 core mismatch") + assert.Equal(t, "ubuntu", p1.Meta.Platform, "Peer 1 platform mismatch") + assert.Equal(t, "Ubuntu 22.04", p1.Meta.OS, "Peer 1 OS version mismatch") + assert.Equal(t, "0.24.0", p1.Meta.WtVersion, "Peer 1 wt version mismatch") + assert.Equal(t, "ABC123", p1.Meta.SystemSerialNumber, "Peer 1 serial number mismatch") + assert.Equal(t, "Server Model X", p1.Meta.SystemProductName, "Peer 1 product name mismatch") + assert.Equal(t, "Dell Inc.", p1.Meta.SystemManufacturer, "Peer 1 manufacturer mismatch") + + // Validate Network Addresses + assert.Len(t, p1.Meta.NetworkAddresses, 2, "Peer 1 should have 2 network addresses") + assert.Equal(t, netip.MustParsePrefix("192.168.1.10/32"), p1.Meta.NetworkAddresses[0].NetIP, "Network address 1 IP mismatch") + assert.Equal(t, "00:11:22:33:44:55", 
p1.Meta.NetworkAddresses[0].Mac, "Network address 1 MAC mismatch") + assert.Equal(t, netip.MustParsePrefix("10.0.0.5/32"), p1.Meta.NetworkAddresses[1].NetIP, "Network address 2 IP mismatch") + assert.Equal(t, "00:11:22:33:44:66", p1.Meta.NetworkAddresses[1].Mac, "Network address 2 MAC mismatch") + + // Validate Peer 1 Status + require.NotNil(t, p1.Status, "Peer 1 status should not be nil") + assert.True(t, p1.Status.Connected, "Peer 1 should be connected") + assert.False(t, p1.Status.LoginExpired, "Peer 1 login should not be expired") + assert.False(t, p1.Status.RequiresApproval, "Peer 1 should not require approval") + + // Validate Peer 1 Location + assert.True(t, p1.Location.ConnectionIP.Equal(net.ParseIP("203.0.113.10")), "Peer 1 connection IP mismatch") + assert.Equal(t, "US", p1.Location.CountryCode, "Peer 1 country code mismatch") + assert.Equal(t, "San Francisco", p1.Location.CityName, "Peer 1 city name mismatch") + assert.Equal(t, uint(5391959), p1.Location.GeoNameID, "Peer 1 geo name ID mismatch") + + // Validate Peer 2 + p2, exists := retrievedAccount.Peers[peerID2] + require.True(t, exists, "Peer 2 should exist") + assert.Equal(t, "Peer 2", p2.Name, "Peer 2 name mismatch") + assert.Equal(t, "peer-key-2-BBBB", p2.Key, "Peer 2 key mismatch") + assert.False(t, p2.SSHEnabled, "Peer 2 SSH should be disabled") + assert.False(t, p2.LoginExpirationEnabled, "Peer 2 login expiration should be disabled") + assert.True(t, p2.InactivityExpirationEnabled, "Peer 2 inactivity expiration should be enabled") + + // Validate Peer 2 Status + require.NotNil(t, p2.Status, "Peer 2 status should not be nil") + assert.False(t, p2.Status.Connected, "Peer 2 should not be connected") + assert.True(t, p2.Status.LoginExpired, "Peer 2 login should be expired") + assert.True(t, p2.Status.RequiresApproval, "Peer 2 should require approval") + + // Validate Peer 3 (Ephemeral) + p3, exists := retrievedAccount.Peers[peerID3] + require.True(t, exists, "Peer 3 should exist") + assert.True(t, 
p3.Ephemeral, "Peer 3 should be ephemeral") + assert.Equal(t, "Peer 3 (Ephemeral)", p3.Name, "Peer 3 name mismatch") + }) + + // ========== VALIDATE USERS ========== + t.Run("Users", func(t *testing.T) { + require.Len(t, retrievedAccount.Users, 2, "Should have 2 users") + + // Validate User 1 + u1, exists := retrievedAccount.Users[userID1] + require.True(t, exists, "User 1 should exist") + assert.Equal(t, types.UserRoleOwner, u1.Role, "User 1 role mismatch") + assert.False(t, u1.IsServiceUser, "User 1 should not be a service user") + assert.True(t, u1.NonDeletable, "User 1 should be non-deletable") + assert.Equal(t, types.UserIssuedAPI, u1.Issued, "User 1 issued type mismatch") + assert.Len(t, u1.AutoGroups, 1, "User 1 auto groups length mismatch") + assert.Contains(t, u1.AutoGroups, groupID1, "User 1 should have group1") + assert.False(t, u1.Blocked, "User 1 should not be blocked") + require.NotNil(t, u1.LastLogin, "User 1 last login should not be nil") + assert.WithinDuration(t, lastLogin, *u1.LastLogin, time.Second, "User 1 last login mismatch") + + // Validate User 1 Integration Reference + assert.Equal(t, 123, u1.IntegrationReference.ID, "User 1 integration ID mismatch") + assert.Equal(t, "azure_ad", u1.IntegrationReference.IntegrationType, "User 1 integration type mismatch") + + // Validate User 1 PATs + require.Len(t, u1.PATs, 2, "User 1 should have 2 PATs") + + pat1Retrieved, exists := u1.PATs[patID1] + require.True(t, exists, "PAT 1 should exist") + assert.Equal(t, "PAT 1", pat1Retrieved.Name, "PAT 1 name mismatch") + assert.Equal(t, "hashed-token-1", pat1Retrieved.HashedToken, "PAT 1 hashed token mismatch") + require.NotNil(t, pat1Retrieved.LastUsed, "PAT 1 last used should not be nil") + assert.WithinDuration(t, patLastUsed, *pat1Retrieved.LastUsed, time.Second, "PAT 1 last used mismatch") + assert.Equal(t, userID1, pat1Retrieved.CreatedBy, "PAT 1 created by mismatch") + assert.Empty(t, pat1Retrieved.UserID, "PAT 1 UserID should be cleared") + + 
pat2Retrieved, exists := u1.PATs[patID2] + require.True(t, exists, "PAT 2 should exist") + assert.Equal(t, "PAT 2", pat2Retrieved.Name, "PAT 2 name mismatch") + assert.Nil(t, pat2Retrieved.LastUsed, "PAT 2 last used should be nil") + + // Validate User 2 + u2, exists := retrievedAccount.Users[userID2] + require.True(t, exists, "User 2 should exist") + assert.Equal(t, types.UserRoleAdmin, u2.Role, "User 2 role mismatch") + assert.True(t, u2.IsServiceUser, "User 2 should be a service user") + assert.False(t, u2.NonDeletable, "User 2 should be deletable") + assert.Equal(t, types.UserIssuedIntegration, u2.Issued, "User 2 issued type mismatch") + assert.Equal(t, "google_workspace", u2.IntegrationReference.IntegrationType, "User 2 integration type mismatch") + + // Validate User 2 PATs + require.Len(t, u2.PATs, 1, "User 2 should have 1 PAT") + pat3Retrieved, exists := u2.PATs[patID3] + require.True(t, exists, "PAT 3 should exist") + assert.Equal(t, "PAT 3", pat3Retrieved.Name, "PAT 3 name mismatch") + }) + + // ========== VALIDATE GROUPS ========== + t.Run("Groups", func(t *testing.T) { + require.Len(t, retrievedAccount.Groups, 2, "Should have 2 groups") + + // Validate Group 1 + g1, exists := retrievedAccount.Groups[groupID1] + require.True(t, exists, "Group 1 should exist") + assert.Equal(t, "Group 1", g1.Name, "Group 1 name mismatch") + assert.Equal(t, types.GroupIssuedAPI, g1.Issued, "Group 1 issued type mismatch") + assert.Len(t, g1.Peers, 2, "Group 1 should have 2 peers") + assert.Contains(t, g1.Peers, peerID1, "Group 1 should contain peer 1") + assert.Contains(t, g1.Peers, peerID2, "Group 1 should contain peer 2") + + // Validate Group 1 Resources + assert.Len(t, g1.Resources, 1, "Group 1 should have 1 resource") + assert.Equal(t, "resource-1", g1.Resources[0].ID, "Group 1 resource ID mismatch") + assert.Equal(t, types.ResourceTypeHost, g1.Resources[0].Type, "Group 1 resource type mismatch") + + // Validate Group 2 + g2, exists := retrievedAccount.Groups[groupID2] 
+ require.True(t, exists, "Group 2 should exist") + assert.Equal(t, "Group 2", g2.Name, "Group 2 name mismatch") + assert.Equal(t, types.GroupIssuedIntegration, g2.Issued, "Group 2 issued type mismatch") + assert.Len(t, g2.Peers, 1, "Group 2 should have 1 peer") + assert.Contains(t, g2.Peers, peerID3, "Group 2 should contain peer 3") + assert.Len(t, g2.Resources, 0, "Group 2 should have 0 resources") + + // Validate Group 2 Integration Reference + assert.Equal(t, 789, g2.IntegrationReference.ID, "Group 2 integration ID mismatch") + assert.Equal(t, "okta", g2.IntegrationReference.IntegrationType, "Group 2 integration type mismatch") + }) + + // ========== VALIDATE POLICIES ========== + t.Run("Policies", func(t *testing.T) { + require.Len(t, retrievedAccount.Policies, 2, "Should have 2 policies") + + // Validate Policy 1 + pol1 := retrievedAccount.Policies[0] + if pol1.ID != policyID1 { + pol1 = retrievedAccount.Policies[1] + } + assert.Equal(t, policyID1, pol1.ID, "Policy 1 ID mismatch") + assert.Equal(t, "Policy 1", pol1.Name, "Policy 1 name mismatch") + assert.Equal(t, "Main access policy", pol1.Description, "Policy 1 description mismatch") + assert.True(t, pol1.Enabled, "Policy 1 should be enabled") + + // Validate Policy 1 Rules + require.Len(t, pol1.Rules, 2, "Policy 1 should have 2 rules") + + rule1 := pol1.Rules[0] + assert.Equal(t, "Rule 1", rule1.Name, "Rule 1 name mismatch") + assert.Equal(t, "Allow access", rule1.Description, "Rule 1 description mismatch") + assert.True(t, rule1.Enabled, "Rule 1 should be enabled") + assert.Equal(t, types.PolicyTrafficActionAccept, rule1.Action, "Rule 1 action mismatch") + assert.True(t, rule1.Bidirectional, "Rule 1 should be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolALL, rule1.Protocol, "Rule 1 protocol mismatch") + assert.Len(t, rule1.Sources, 1, "Rule 1 sources length mismatch") + assert.Contains(t, rule1.Sources, groupID1, "Rule 1 should have group1 as source") + assert.Len(t, rule1.Destinations, 1, 
"Rule 1 destinations length mismatch") + assert.Contains(t, rule1.Destinations, groupID2, "Rule 1 should have group2 as destination") + + rule2 := pol1.Rules[1] + assert.Equal(t, "Rule 2", rule2.Name, "Rule 2 name mismatch") + assert.Equal(t, types.PolicyTrafficActionDrop, rule2.Action, "Rule 2 action mismatch") + assert.False(t, rule2.Bidirectional, "Rule 2 should not be bidirectional") + assert.Equal(t, types.PolicyRuleProtocolTCP, rule2.Protocol, "Rule 2 protocol mismatch") + assert.Len(t, rule2.Ports, 2, "Rule 2 ports length mismatch") + assert.Contains(t, rule2.Ports, "22", "Rule 2 should have port 22") + assert.Contains(t, rule2.Ports, "3389", "Rule 2 should have port 3389") + assert.Len(t, rule2.PortRanges, 1, "Rule 2 port ranges length mismatch") + assert.Equal(t, uint16(8000), rule2.PortRanges[0].Start, "Rule 2 port range start mismatch") + assert.Equal(t, uint16(8999), rule2.PortRanges[0].End, "Rule 2 port range end mismatch") + + // Validate Policy 2 + pol2 := retrievedAccount.Policies[1] + if pol2.ID != policyID2 { + pol2 = retrievedAccount.Policies[0] + } + assert.Equal(t, policyID2, pol2.ID, "Policy 2 ID mismatch") + assert.Equal(t, "Policy 2", pol2.Name, "Policy 2 name mismatch") + assert.False(t, pol2.Enabled, "Policy 2 should be disabled") + require.Len(t, pol2.Rules, 1, "Policy 2 should have 1 rule") + + rule3 := pol2.Rules[0] + assert.Equal(t, "Rule 3", rule3.Name, "Rule 3 name mismatch") + assert.False(t, rule3.Enabled, "Rule 3 should be disabled") + assert.Equal(t, types.PolicyRuleProtocolUDP, rule3.Protocol, "Rule 3 protocol mismatch") + }) + + // ========== VALIDATE ROUTES ========== + t.Run("Routes", func(t *testing.T) { + require.Len(t, retrievedAccount.Routes, 2, "Should have 2 routes") + + // Validate Route 1 + r1, exists := retrievedAccount.Routes[routeID1] + require.True(t, exists, "Route 1 should exist") + assert.Equal(t, "Route 1", r1.Description, "Route 1 description mismatch") + assert.Equal(t, route.IPv4Network, r1.NetworkType, 
"Route 1 network type mismatch") + assert.Equal(t, peerID1, r1.Peer, "Route 1 peer mismatch") + assert.Empty(t, r1.PeerGroups, "Route 1 peer groups should be empty") + assert.Equal(t, route.NetID("net-id-1"), r1.NetID, "Route 1 net ID mismatch") + assert.True(t, r1.Masquerade, "Route 1 masquerade should be enabled") + assert.Equal(t, 9999, r1.Metric, "Route 1 metric mismatch") + assert.True(t, r1.Enabled, "Route 1 should be enabled") + assert.Len(t, r1.Groups, 1, "Route 1 groups length mismatch") + assert.Contains(t, r1.Groups, groupID1, "Route 1 should have group1") + assert.Len(t, r1.AccessControlGroups, 1, "Route 1 ACL groups length mismatch") + assert.Contains(t, r1.AccessControlGroups, groupID2, "Route 1 should have group2 in ACL") + + // Validate Route 1 Network CIDR + assert.Equal(t, "10.0.0.0/24", r1.Network.String(), "Route 1 network CIDR mismatch") + + // Validate Route 2 + r2, exists := retrievedAccount.Routes[routeID2] + require.True(t, exists, "Route 2 should exist") + assert.Equal(t, "Route 2 (High Availability)", r2.Description, "Route 2 description mismatch") + assert.Empty(t, r2.Peer, "Route 2 peer should be empty") + assert.Len(t, r2.PeerGroups, 1, "Route 2 peer groups length mismatch") + assert.Contains(t, r2.PeerGroups, groupID2, "Route 2 should have group2 as peer group") + assert.False(t, r2.Masquerade, "Route 2 masquerade should be disabled") + assert.Equal(t, 100, r2.Metric, "Route 2 metric mismatch") + assert.Equal(t, "192.168.1.0/24", r2.Network.String(), "Route 2 network CIDR mismatch") + }) + + // ========== VALIDATE NAME SERVER GROUPS ========== + t.Run("NameServerGroups", func(t *testing.T) { + require.Len(t, retrievedAccount.NameServerGroups, 2, "Should have 2 nameserver groups") + + // Validate NS Group 1 + nsg1, exists := retrievedAccount.NameServerGroups[nsGroupID1] + require.True(t, exists, "NS Group 1 should exist") + assert.Equal(t, "NS Group 1", nsg1.Name, "NS Group 1 name mismatch") + assert.Equal(t, "Primary nameservers", 
nsg1.Description, "NS Group 1 description mismatch") + assert.True(t, nsg1.Enabled, "NS Group 1 should be enabled") + assert.True(t, nsg1.Primary, "NS Group 1 should be primary") + assert.True(t, nsg1.SearchDomainsEnabled, "NS Group 1 search domains should be enabled") + assert.Empty(t, nsg1.AccountID, "NS Group 1 AccountID should be cleared") + + // Validate NS Group 1 NameServers + require.Len(t, nsg1.NameServers, 2, "NS Group 1 should have 2 nameservers") + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), nsg1.NameServers[0].IP, "NS Group 1 nameserver 1 IP mismatch") + assert.Equal(t, nbdns.UDPNameServerType, nsg1.NameServers[0].NSType, "NS Group 1 nameserver 1 type mismatch") + assert.Equal(t, 53, nsg1.NameServers[0].Port, "NS Group 1 nameserver 1 port mismatch") + assert.Equal(t, netip.MustParseAddr("8.8.4.4"), nsg1.NameServers[1].IP, "NS Group 1 nameserver 2 IP mismatch") + + // Validate NS Group 1 Groups and Domains + assert.Len(t, nsg1.Groups, 2, "NS Group 1 groups length mismatch") + assert.Contains(t, nsg1.Groups, groupID1, "NS Group 1 should have group1") + assert.Contains(t, nsg1.Groups, groupID2, "NS Group 1 should have group2") + assert.Len(t, nsg1.Domains, 2, "NS Group 1 domains length mismatch") + assert.Contains(t, nsg1.Domains, "example.com", "NS Group 1 should have example.com domain") + assert.Contains(t, nsg1.Domains, "test.com", "NS Group 1 should have test.com domain") + + // Validate NS Group 2 + nsg2, exists := retrievedAccount.NameServerGroups[nsGroupID2] + require.True(t, exists, "NS Group 2 should exist") + assert.Equal(t, "NS Group 2", nsg2.Name, "NS Group 2 name mismatch") + assert.False(t, nsg2.Enabled, "NS Group 2 should be disabled") + assert.False(t, nsg2.Primary, "NS Group 2 should not be primary") + assert.False(t, nsg2.SearchDomainsEnabled, "NS Group 2 search domains should be disabled") + assert.Len(t, nsg2.NameServers, 1, "NS Group 2 should have 1 nameserver") + assert.Len(t, nsg2.Groups, 0, "NS Group 2 should have empty 
groups") + assert.Len(t, nsg2.Domains, 0, "NS Group 2 should have empty domains") + }) + + // ========== VALIDATE POSTURE CHECKS ========== + t.Run("PostureChecks", func(t *testing.T) { + require.Len(t, retrievedAccount.PostureChecks, 2, "Should have 2 posture checks") + + // Find posture checks by ID + var pc1, pc2 *posture.Checks + for _, pc := range retrievedAccount.PostureChecks { + if pc.ID == postureCheckID1 { + pc1 = pc + } else if pc.ID == postureCheckID2 { + pc2 = pc + } + } + + // Validate Posture Check 1 + require.NotNil(t, pc1, "Posture check 1 should exist") + assert.Equal(t, "Posture Check 1", pc1.Name, "Posture check 1 name mismatch") + assert.Equal(t, "OS version check", pc1.Description, "Posture check 1 description mismatch") + + // Validate NB Version Check + require.NotNil(t, pc1.Checks.NBVersionCheck, "NB version check should not be nil") + assert.Equal(t, "0.24.0", pc1.Checks.NBVersionCheck.MinVersion, "NB version check min version mismatch") + + // Validate OS Version Check + require.NotNil(t, pc1.Checks.OSVersionCheck, "OS version check should not be nil") + require.NotNil(t, pc1.Checks.OSVersionCheck.Ios, "iOS version check should not be nil") + assert.Equal(t, "16.0", pc1.Checks.OSVersionCheck.Ios.MinVersion, "iOS min version mismatch") + require.NotNil(t, pc1.Checks.OSVersionCheck.Darwin, "Darwin version check should not be nil") + assert.Equal(t, "22.0.0", pc1.Checks.OSVersionCheck.Darwin.MinVersion, "Darwin min version mismatch") + + // Validate Posture Check 2 + require.NotNil(t, pc2, "Posture check 2 should exist") + assert.Equal(t, "Posture Check 2", pc2.Name, "Posture check 2 name mismatch") + + // Validate Geo Location Check + require.NotNil(t, pc2.Checks.GeoLocationCheck, "Geo location check should not be nil") + assert.Equal(t, "allow", pc2.Checks.GeoLocationCheck.Action, "Geo location action mismatch") + assert.Len(t, pc2.Checks.GeoLocationCheck.Locations, 2, "Geo location check should have 2 locations") + assert.Equal(t, "US", 
pc2.Checks.GeoLocationCheck.Locations[0].CountryCode, "Location 1 country code mismatch") + assert.Equal(t, "San Francisco", pc2.Checks.GeoLocationCheck.Locations[0].CityName, "Location 1 city name mismatch") + assert.Equal(t, "GB", pc2.Checks.GeoLocationCheck.Locations[1].CountryCode, "Location 2 country code mismatch") + assert.Equal(t, "London", pc2.Checks.GeoLocationCheck.Locations[1].CityName, "Location 2 city name mismatch") + + // Validate Peer Network Range Check + require.NotNil(t, pc2.Checks.PeerNetworkRangeCheck, "Peer network range check should not be nil") + assert.Equal(t, "allow", pc2.Checks.PeerNetworkRangeCheck.Action, "Peer network range action mismatch") + assert.Len(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, 2, "Peer network range check should have 2 ranges") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("192.168.0.0/16"), "Should have 192.168.0.0/16 range") + assert.Contains(t, pc2.Checks.PeerNetworkRangeCheck.Ranges, netip.MustParsePrefix("10.0.0.0/8"), "Should have 10.0.0.0/8 range") + }) + + // ========== VALIDATE NETWORKS ========== + t.Run("Networks", func(t *testing.T) { + require.Len(t, retrievedAccount.Networks, 1, "Should have 1 network") + + net1 := retrievedAccount.Networks[0] + assert.Equal(t, networkID1, net1.ID, "Network 1 ID mismatch") + assert.Equal(t, "Network 1", net1.Name, "Network 1 name mismatch") + assert.Equal(t, "Primary network", net1.Description, "Network 1 description mismatch") + }) + + // ========== VALIDATE NETWORK ROUTERS ========== + t.Run("NetworkRouters", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkRouters, 1, "Should have 1 network router") + + router := retrievedAccount.NetworkRouters[0] + assert.Equal(t, routerID1, router.ID, "Router 1 ID mismatch") + assert.Equal(t, networkID1, router.NetworkID, "Router 1 network ID mismatch") + assert.Equal(t, peerID1, router.Peer, "Router 1 peer mismatch") + assert.Empty(t, router.PeerGroups, "Router 1 peer groups 
should be empty") + assert.True(t, router.Masquerade, "Router 1 masquerade should be enabled") + assert.Equal(t, 100, router.Metric, "Router 1 metric mismatch") + }) + + // ========== VALIDATE NETWORK RESOURCES ========== + t.Run("NetworkResources", func(t *testing.T) { + require.Len(t, retrievedAccount.NetworkResources, 1, "Should have 1 network resource") + + res := retrievedAccount.NetworkResources[0] + assert.Equal(t, resourceID1, res.ID, "Resource 1 ID mismatch") + assert.Equal(t, networkID1, res.NetworkID, "Resource 1 network ID mismatch") + assert.Equal(t, "Resource 1", res.Name, "Resource 1 name mismatch") + assert.Equal(t, "Web server", res.Description, "Resource 1 description mismatch") + assert.Equal(t, netip.MustParsePrefix("192.168.1.100/32"), res.Prefix, "Resource 1 prefix mismatch") + assert.Equal(t, resourceTypes.Host, res.Type, "Resource 1 type mismatch") + }) + + // ========== VALIDATE ONBOARDING ========== + t.Run("Onboarding", func(t *testing.T) { + assert.Equal(t, accountID, retrievedAccount.Onboarding.AccountID, "Onboarding account ID mismatch") + assert.True(t, retrievedAccount.Onboarding.OnboardingFlowPending, "Onboarding flow should be pending") + assert.False(t, retrievedAccount.Onboarding.SignupFormPending, "Signup form should not be pending") + assert.WithinDuration(t, now, retrievedAccount.Onboarding.CreatedAt, time.Second, "Onboarding created at mismatch") + }) + + t.Log("✅ All comprehensive account field validations passed!") +} diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go new file mode 100644 index 000000000..350a1da83 --- /dev/null +++ b/management/server/store/sqlstore_bench_test.go @@ -0,0 +1,951 @@ +package store + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "sort" + "sync" + "testing" + "time" + + "gorm.io/driver/postgres" + "gorm.io/gorm" + "gorm.io/gorm/clause" + + "github.com/jackc/pgx/v5/pgxpool" + log "github.com/sirupsen/logrus" + 
"github.com/stretchr/testify/assert" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/status" +) + +func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Omit("GroupsG"). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload(clause.Associations). 
+ Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + // we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us + for i, policy := range account.Policies { + var rules []*types.PolicyRule + err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error + if err != nil { + return nil, status.Errorf(status.NotFound, "rule not found") + } + account.Policies[i].Rules = rules + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + account.SetupKeys[key.Key] = key.Copy() + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = peer.Copy() + } + account.PeersG = nil + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + user.PATs[pat.ID] = pat.Copy() + } + account.Users[user.Id] = user.Copy() + } + account.UsersG = nil + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + account.Groups[group.ID] = group.Copy() + } + account.GroupsG = nil + + var groupPeers []types.GroupPeer + s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID). 
+ Find(&groupPeers) + for _, groupPeer := range groupPeers { + if group, ok := account.Groups[groupPeer.GroupID]; ok { + group.Peers = append(group.Peers, groupPeer.PeerID) + } else { + log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID) + } + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = route.Copy() + } + account.RoutesG = nil + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + account.NameServerGroups[ns.ID] = ns.Copy() + } + account.NameServerGroupsG = nil + + return &account, nil +} + +func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("UsersG.PATsG"). // have to be specified as this is nested reference + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("UsersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). 
+ Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + account.Users = make(map[string]*types.User, len(account.UsersG)) + for _, user := range account.UsersG { + user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs)) + for _, pat := range user.PATsG { + pat.UserID = "" + user.PATs[pat.ID] = &pat + } + if user.AutoGroups == nil { + user.AutoGroups = []string{} + } + account.Users[user.Id] = &user + user.PATsG = nil + } + account.UsersG = nil + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if 
ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + return &account, nil +} + +func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) { + config, err := pgxpool.ParseConfig(dsn) + if err != nil { + return nil, fmt.Errorf("unable to parse database config: %w", err) + } + + config.MaxConns = 12 + config.MinConns = 2 + config.MaxConnLifetime = time.Hour + config.HealthCheckPeriod = time.Minute + + pool, err := pgxpool.NewWithConfig(ctx, config) + if err != nil { + return nil, fmt.Errorf("unable to create connection pool: %w", err) + } + + if err := pool.Ping(ctx); err != nil { + pool.Close() + return nil, fmt.Errorf("unable to ping database: %w", err) + } + return pool, nil +} + +func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) { + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + b.Fatalf("failed to create test container: %v", err) + } + + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + pool, err := connectDBforTest(context.Background(), dsn) + if err != nil { + b.Fatalf("failed to connect database: %v", err) + } + + models := []interface{}{ + &types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, + &types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{}, + &types.Policy{}, &types.PolicyRule{}, &route.Route{}, + &nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{}, + &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &types.AccountOnboarding{}, + } + + for i := len(models) - 1; i >= 0; i-- { + err := db.Migrator().DropTable(models[i]) + if err != nil { + b.Fatalf("failed to drop table: %v", err) + } + } + + err = db.AutoMigrate(models...) 
+ if err != nil { + b.Fatalf("failed to migrate database: %v", err) + } + + store := &SqlStore{ + db: db, + pool: pool, + } + + const ( + accountID = "benchmark-account-id" + numUsers = 20 + numPatsPerUser = 3 + numSetupKeys = 25 + numPeers = 200 + numGroups = 30 + numPolicies = 50 + numRulesPerPolicy = 10 + numRoutes = 40 + numNSGroups = 10 + numPostureChecks = 15 + numNetworks = 5 + numNetworkRouters = 5 + numNetworkResources = 10 + ) + + _, ipNet, _ := net.ParseCIDR("100.64.0.0/10") + acc := types.Account{ + Id: accountID, + CreatedBy: "benchmark-user", + CreatedAt: time.Now(), + Domain: "benchmark.com", + IsDomainPrimaryAccount: true, + Network: &types.Network{ + Identifier: "benchmark-net", + Net: *ipNet, + Serial: 1, + }, + DNSSettings: types.DNSSettings{ + DisabledManagementGroups: []string{"group-disabled-1"}, + }, + Settings: &types.Settings{}, + } + if err := db.Create(&acc).Error; err != nil { + b.Fatalf("create account: %v", err) + } + + var setupKeys []types.SetupKey + for i := 0; i < numSetupKeys; i++ { + setupKeys = append(setupKeys, types.SetupKey{ + Id: fmt.Sprintf("keyid-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("key-%d", i), + Name: fmt.Sprintf("Benchmark Key %d", i), + ExpiresAt: &time.Time{}, + }) + } + if err := db.Create(&setupKeys).Error; err != nil { + b.Fatalf("create setup keys: %v", err) + } + + var peers []nbpeer.Peer + for i := 0; i < numPeers; i++ { + peers = append(peers, nbpeer.Peer{ + ID: fmt.Sprintf("peer-%d", i), + AccountID: accountID, + Key: fmt.Sprintf("peerkey-%d", i), + IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)), + Name: fmt.Sprintf("peer-name-%d", i), + Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()}, + }) + } + if err := db.Create(&peers).Error; err != nil { + b.Fatalf("create peers: %v", err) + } + + for i := 0; i < numUsers; i++ { + userID := fmt.Sprintf("user-%d", i) + user := types.User{Id: userID, AccountID: accountID} + if err := db.Create(&user).Error; err != nil { + 
b.Fatalf("create user %s: %v", userID, err) + } + + var pats []types.PersonalAccessToken + for j := 0; j < numPatsPerUser; j++ { + pats = append(pats, types.PersonalAccessToken{ + ID: fmt.Sprintf("pat-%d-%d", i, j), + UserID: userID, + Name: fmt.Sprintf("PAT %d for User %d", j, i), + }) + } + if err := db.Create(&pats).Error; err != nil { + b.Fatalf("create pats for user %s: %v", userID, err) + } + } + + var groups []*types.Group + for i := 0; i < numGroups; i++ { + groups = append(groups, &types.Group{ + ID: fmt.Sprintf("group-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Group %d", i), + }) + } + if err := db.Create(&groups).Error; err != nil { + b.Fatalf("create groups: %v", err) + } + + for i := 0; i < numPolicies; i++ { + policyID := fmt.Sprintf("policy-%d", i) + policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true} + if err := db.Create(&policy).Error; err != nil { + b.Fatalf("create policy %s: %v", policyID, err) + } + + var rules []*types.PolicyRule + for j := 0; j < numRulesPerPolicy; j++ { + rules = append(rules, &types.PolicyRule{ + ID: fmt.Sprintf("rule-%d-%d", i, j), + PolicyID: policyID, + Name: fmt.Sprintf("Rule %d for Policy %d", j, i), + Enabled: true, + Protocol: "all", + }) + } + if err := db.Create(&rules).Error; err != nil { + b.Fatalf("create rules for policy %s: %v", policyID, err) + } + } + + var routes []route.Route + for i := 0; i < numRoutes; i++ { + routes = append(routes, route.Route{ + ID: route.ID(fmt.Sprintf("route-%d", i)), + AccountID: accountID, + Description: fmt.Sprintf("Route %d", i), + Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)), + Enabled: true, + }) + } + if err := db.Create(&routes).Error; err != nil { + b.Fatalf("create routes: %v", err) + } + + var nsGroups []nbdns.NameServerGroup + for i := 0; i < numNSGroups; i++ { + nsGroups = append(nsGroups, nbdns.NameServerGroup{ + ID: fmt.Sprintf("nsg-%d", i), + AccountID: accountID, + Name: 
fmt.Sprintf("NS Group %d", i), + Description: "Benchmark NS Group", + Enabled: true, + }) + } + if err := db.Create(&nsGroups).Error; err != nil { + b.Fatalf("create nsgroups: %v", err) + } + + var postureChecks []*posture.Checks + for i := 0; i < numPostureChecks; i++ { + postureChecks = append(postureChecks, &posture.Checks{ + ID: fmt.Sprintf("pc-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Posture Check %d", i), + }) + } + if err := db.Create(&postureChecks).Error; err != nil { + b.Fatalf("create posture checks: %v", err) + } + + var networks []*networkTypes.Network + for i := 0; i < numNetworks; i++ { + networks = append(networks, &networkTypes.Network{ + ID: fmt.Sprintf("nettype-%d", i), + AccountID: accountID, + Name: fmt.Sprintf("Network Type %d", i), + }) + } + if err := db.Create(&networks).Error; err != nil { + b.Fatalf("create networks: %v", err) + } + + var networkRouters []*routerTypes.NetworkRouter + for i := 0; i < numNetworkRouters; i++ { + networkRouters = append(networkRouters, &routerTypes.NetworkRouter{ + ID: fmt.Sprintf("router-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Peer: peers[i%numPeers].ID, + }) + } + if err := db.Create(&networkRouters).Error; err != nil { + b.Fatalf("create network routers: %v", err) + } + + var networkResources []*resourceTypes.NetworkResource + for i := 0; i < numNetworkResources; i++ { + networkResources = append(networkResources, &resourceTypes.NetworkResource{ + ID: fmt.Sprintf("resource-%d", i), + AccountID: accountID, + NetworkID: networks[i%numNetworks].ID, + Name: fmt.Sprintf("Resource %d", i), + }) + } + if err := db.Create(&networkResources).Error; err != nil { + b.Fatalf("create network resources: %v", err) + } + + onboarding := types.AccountOnboarding{ + AccountID: accountID, + OnboardingFlowPending: true, + } + if err := db.Create(&onboarding).Error; err != nil { + b.Fatalf("create onboarding: %v", err) + } + + return store, cleanup, accountID +} + +func 
BenchmarkGetAccount(b *testing.B) { + store, cleanup, accountID := setupBenchmarkDB(b) + defer cleanup() + ctx := context.Background() + b.ResetTimer() + b.ReportAllocs() + b.Run("old", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountSlow(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountSlow failed: %v", err) + } + } + }) + b.Run("gorm opt", func(b *testing.B) { + for range b.N { + _, err := store.GetAccountGormOpt(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountFast failed: %v", err) + } + } + }) + b.Run("raw", func(b *testing.B) { + for range b.N { + _, err := store.GetAccount(ctx, accountID) + if err != nil { + b.Fatalf("GetAccountPureSQL failed: %v", err) + } + } + }) + store.pool.Close() +} + +func TestAccountEquivalence(t *testing.T) { + store, cleanup, accountID := setupBenchmarkDB(t) + defer cleanup() + ctx := context.Background() + + type getAccountFunc func(context.Context, string) (*types.Account, error) + + tests := []struct { + name string + expectedF getAccountFunc + actualF getAccountFunc + }{ + {"old vs new", store.GetAccountSlow, store.GetAccountGormOpt}, + {"old vs raw", store.GetAccountSlow, store.GetAccount}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + expected, errOld := tt.expectedF(ctx, accountID) + assert.NoError(t, errOld, "expected function should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := tt.actualF(ctx, accountID) + assert.NoError(t, errNew, "actual function should not return an error") + assert.NotNil(t, actual, "actual should not be nil") + testAccountEquivalence(t, expected, actual) + }) + } + + expected, errOld := store.GetAccountSlow(ctx, accountID) + assert.NoError(t, errOld, "GetAccountSlow should not return an error") + assert.NotNil(t, expected, "expected should not be nil") + + actual, errNew := store.GetAccount(ctx, accountID) + assert.NoError(t, errNew, "GetAccount (new) should not return an error") + 
assert.NotNil(t, actual, "actual should not be nil") +} + +func testAccountEquivalence(t *testing.T, expected, actual *types.Account) { + assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal") + assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal") + assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second") + assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal") + assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal") + assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal") + assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal") + assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal") + assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal") + + assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements") + for key, oldVal := range expected.SetupKeys { + newVal, ok := actual.SetupKeys[key] + assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key) + } + + assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements") + for key, oldVal := range expected.Peers { + newVal, ok := actual.Peers[key] + assert.True(t, ok, "Peer with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements") + for key, oldUser := range expected.Users { + newUser, ok := 
actual.Users[key] + assert.True(t, ok, "User with ID '%s' should exist in new account", key) + + assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key) + for patKey, oldPAT := range oldUser.PATs { + newPAT, patOk := newUser.PATs[patKey] + assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key) + assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key) + } + + oldUser.PATs = nil + newUser.PATs = nil + assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key) + } + + assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements") + for key, oldVal := range expected.Groups { + newVal, ok := actual.Groups[key] + assert.True(t, ok, "Group with ID '%s' should exist in new account", key) + sort.Strings(oldVal.Peers) + sort.Strings(newVal.Peers) + assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements") + for key, oldVal := range expected.Routes { + newVal, ok := actual.Routes[key] + assert.True(t, ok, "Route with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key) + } + + assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements") + for key, oldVal := range expected.NameServerGroups { + newVal, ok := actual.NameServerGroups[key] + assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key) + assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key) + } + + assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements") + sort.Slice(expected.Policies, func(i, j int) bool { return 
expected.Policies[i].ID < expected.Policies[j].ID }) + sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID }) + for i := range expected.Policies { + sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID }) + sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID }) + assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID) + } + + assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements") + sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID }) + sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID }) + for i := range expected.PostureChecks { + assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID) + } + + assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements") + sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID }) + sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID }) + for i := range expected.Networks { + assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID) + } + + assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements") + sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID }) + sort.Slice(actual.NetworkRouters, 
func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID }) + for i := range expected.NetworkRouters { + assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID) + } + + assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements") + sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID }) + sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID }) + for i := range expected.NetworkResources { + assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID) + } +} + +func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 12) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + users, err := s.getUsers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.UsersG = users + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + 
account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var userIDs []string + for _, u := range account.UsersG { + userIDs = append(userIDs, u.Id) + } + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(3) + errChan = make(chan error, 3) + + var pats []types.PersonalAccessToken + go func() { + defer wg.Done() + var err error + pats, err = s.getPersonalAccessTokens(ctx, userIDs) + if err != nil { + errChan <- err + } + }() + + var rules []*types.PolicyRule + go func() { + defer 
wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + patsByUserID := make(map[string][]*types.PersonalAccessToken) + for i := range pats { + pat := &pats[i] + patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat) + pat.UserID = "" + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User, len(account.UsersG)) + for i := range account.UsersG { + user := &account.UsersG[i] + user.PATs = make(map[string]*types.PersonalAccessToken) + if userPats, ok := patsByUserID[user.Id]; ok { + for j := range userPats { + pat := userPats[j] + user.PATs[pat.ID] = pat + } + } + account.Users[user.Id] = user + } + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := 
peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index 21b660d96..007e2b739 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine) closeConnection := func() { cleanup() store.Close(ctx) + if store.pool != nil { + store.pool.Close() + } } return store, closeConnection, nil @@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv) } - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err) } dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB, _ := db.DB() + if sqlDB != nil { + sqlDB.Close() + } + if err != nil { return nil, nil, err } @@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv) } - db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + db, err := openDBWithRetry(dsn, kind, 5) if err != nil { return nil, 
nil, fmt.Errorf("failed to open mysql connection: %v", err) } + sqlDB, err := db.DB() + if err != nil { + return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err) + } + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(1) + dsn, cleanup, err := createRandomDB(dsn, db, kind) + + sqlDB.Close() + if err != nil { return nil, nil, err } @@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine return store, cleanup, nil } +func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) { + var db *gorm.DB + var err error + + for i := range maxRetries { + switch engine { + case types.PostgresStoreEngine: + db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{}) + case types.MysqlStoreEngine: + db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{}) + } + + if err == nil { + return db, nil + } + + if i < maxRetries-1 { + waitTime := time.Duration(100*(i+1)) * time.Millisecond + time.Sleep(waitTime) + } + } + + return nil, err +} + func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) { dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_")) @@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func( return "", nil, fmt.Errorf("failed to create database: %v", err) } - var err error + originalDSN := dsn + cleanup := func() { + var dropDB *gorm.DB + var err error + switch engine { case types.PostgresStoreEngine: - err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error + dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { 
+ sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error + case types.MysqlStoreEngine: - // err = killMySQLConnections(dsn, dbName) - err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error + dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{ + SkipDefaultTransaction: true, + PrepareStmt: false, + }) + if err != nil { + log.Errorf("failed to connect for dropping database %s: %v", dbName, err) + return + } + defer func() { + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.Close() + } + }() + + if sqlDB, _ := dropDB.DB(); sqlDB != nil { + sqlDB.SetMaxOpenConns(1) + sqlDB.SetMaxIdleConns(0) + sqlDB.SetConnMaxLifetime(time.Second) + } + + err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error } + if err != nil { log.Errorf("failed to drop database %s: %v", dbName, err) - panic(err) } - sqlDB, _ := db.DB() - _ = sqlDB.Close() } return replaceDBName(dsn, dbName), cleanup, nil diff --git a/management/server/types/account.go b/management/server/types/account.go index f830023c7..8797e1fa3 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -8,6 +8,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/hashicorp/go-multierror" @@ -87,6 +88,13 @@ type Account struct { NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` + + NetworkMapCache *NetworkMapBuilder `gorm:"-"` + nmapInitOnce *sync.Once `gorm:"-"` +} + +func (a *Account) InitOnce() { + a.nmapInitOnce = &sync.Once{} } // this class is used by gorm only @@ -257,7 +265,6 @@ func (a *Account) GetPeerNetworkMap( metrics 
*telemetry.AccountManagerMetrics, ) *NetworkMap { start := time.Now() - peer := a.Peers[peerID] if peer == nil { return &NetworkMap{ @@ -301,7 +308,7 @@ func (a *Account) GetPeerNetworkMap( if dnsManagementStatus { var zones []nbdns.CustomZone if peersCustomZone.Domain != "" { - records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnect) + records := filterZoneRecordsForPeers(peer, peersCustomZone, peersToConnectIncludingRouters, expiredPeers) zones = append(zones, nbdns.CustomZone{ Domain: peersCustomZone.Domain, Records: records, @@ -890,6 +897,8 @@ func (a *Account) Copy() *Account { NetworkRouters: networkRouters, NetworkResources: networkResources, Onboarding: a.Onboarding, + NetworkMapCache: a.NetworkMapCache, + nmapInitOnce: a.nmapInitOnce, } } @@ -1049,14 +1058,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer rules := make([]*FirewallRule, 0) peers := make([]*nbpeer.Peer, 0) - all, err := a.GetGroupAll() - if err != nil { - log.WithContext(ctx).Errorf("failed to get group all: %v", err) - all = &Group{} - } - return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) { - isAll := (len(all.Peers) - 1) == len(groupPeers) for _, peer := range groupPeers { if peer == nil { continue @@ -1075,10 +1077,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer Protocol: string(rule.Protocol), } - if isAll { - fr.PeerIP = "0.0.0.0" - } - ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) + fr.Protocol + fr.Action + strings.Join(rule.Ports, ",") if _, ok := rulesExists[ruleID]; ok { @@ -1682,7 +1680,7 @@ func peerSupportsPortRanges(peerVer string) bool { } // filterZoneRecordsForPeers filters DNS records to only include peers to connect. 
-func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect []*nbpeer.Peer) []nbdns.SimpleRecord { +func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, peersToConnect, expiredPeers []*nbpeer.Peer) []nbdns.SimpleRecord { filteredRecords := make([]nbdns.SimpleRecord, 0, len(customZone.Records)) peerIPs := make(map[string]struct{}) @@ -1693,6 +1691,10 @@ func filterZoneRecordsForPeers(peer *nbpeer.Peer, customZone nbdns.CustomZone, p peerIPs[peerToConnect.IP.String()] = struct{}{} } + for _, expiredPeer := range expiredPeers { + peerIPs[expiredPeer.IP.String()] = struct{}{} + } + for _, record := range customZone.Records { if _, exists := peerIPs[record.RData]; exists { filteredRecords = append(filteredRecords, record) diff --git a/management/server/types/account_test.go b/management/server/types/account_test.go index cd221b590..32538933a 100644 --- a/management/server/types/account_test.go +++ b/management/server/types/account_test.go @@ -845,6 +845,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { peer *nbpeer.Peer customZone nbdns.CustomZone peersToConnect []*nbpeer.Peer + expiredPeers []*nbpeer.Peer expectedRecords []nbdns.SimpleRecord }{ { @@ -857,6 +858,7 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { }, }, peersToConnect: []*nbpeer.Peer{}, + expiredPeers: []*nbpeer.Peer{}, peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, @@ -890,7 +892,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { } return peers }(), - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: func() []nbdns.SimpleRecord { var records []nbdns.SimpleRecord for _, i := range []int{1, 5, 10, 25, 50, 75, 100} { @@ -924,7 
+927,8 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {ID: "peer1", IP: net.ParseIP("10.0.0.1"), DNSLabel: "peer1", ExtraDNSLabels: []string{"peer1-alt", "peer1-backup"}}, {ID: "peer2", IP: net.ParseIP("10.0.0.2"), DNSLabel: "peer2", ExtraDNSLabels: []string{"peer2-service"}}, }, - peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expiredPeers: []*nbpeer.Peer{}, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, expectedRecords: []nbdns.SimpleRecord{ {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, {Name: "peer1-alt.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, @@ -934,11 +938,35 @@ func Test_FilterZoneRecordsForPeers(t *testing.T) { {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, }, }, + { + name: "expired peers are included in DNS entries", + customZone: nbdns.CustomZone{ + Domain: "netbird.cloud.", + Records: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "peer2.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.2"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, + peersToConnect: []*nbpeer.Peer{ + {ID: "peer1", IP: net.ParseIP("10.0.0.1")}, + }, + expiredPeers: []*nbpeer.Peer{ + {ID: "expired-peer", IP: net.ParseIP("10.0.0.99")}, + }, + peer: &nbpeer.Peer{ID: "router", IP: net.ParseIP("10.0.0.100")}, + expectedRecords: []nbdns.SimpleRecord{ + {Name: "peer1.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.1"}, + {Name: "expired-peer.netbird.cloud", Type: int(dns.TypeA), Class: 
nbdns.DefaultClass, TTL: 300, RData: "10.0.0.99"}, + {Name: "router.netbird.cloud", Type: int(dns.TypeA), Class: nbdns.DefaultClass, TTL: 300, RData: "10.0.0.100"}, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect) + result := filterZoneRecordsForPeers(tt.peer, tt.customZone, tt.peersToConnect, tt.expiredPeers) assert.Equal(t, len(tt.expectedRecords), len(result)) assert.ElementsMatch(t, tt.expectedRecords, result) }) diff --git a/management/server/types/holder.go b/management/server/types/holder.go new file mode 100644 index 000000000..3996db2b6 --- /dev/null +++ b/management/server/types/holder.go @@ -0,0 +1,43 @@ +package types + +import ( + "context" + "sync" +) + +type Holder struct { + mu sync.RWMutex + accounts map[string]*Account +} + +func NewHolder() *Holder { + return &Holder{ + accounts: make(map[string]*Account), + } +} + +func (h *Holder) GetAccount(id string) *Account { + h.mu.RLock() + defer h.mu.RUnlock() + return h.accounts[id] +} + +func (h *Holder) AddAccount(account *Account) { + h.mu.Lock() + defer h.mu.Unlock() + h.accounts[account.Id] = account +} + +func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) { + h.mu.Lock() + defer h.mu.Unlock() + if acc, ok := h.accounts[id]; ok { + return acc, nil + } + account, err := accGetter(context.Background(), id) + if err != nil { + return nil, err + } + h.accounts[id] = account + return account, nil +} diff --git a/management/server/types/networkmap.go b/management/server/types/networkmap.go new file mode 100644 index 000000000..c1099726f --- /dev/null +++ b/management/server/types/networkmap.go @@ -0,0 +1,58 @@ +package types + +import ( + "context" + + nbdns "github.com/netbirdio/netbird/dns" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" +) + +func (a 
*Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) { + if a.NetworkMapCache != nil { + return + } + a.nmapInitOnce.Do(func() { + a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers) + }) +} + +func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} + +func (a *Account) GetPeerNetworkMapExp( + ctx context.Context, + peerID string, + peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, + metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + a.initNetworkMapBuilder(validatedPeers) + return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics) +} + +func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerAddedIncremental(peerId) +} + +func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error { + if a.NetworkMapCache == nil { + return nil + } + return a.NetworkMapCache.OnPeerDeleted(peerId) +} + +func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) { + if a.NetworkMapCache == nil { + return + } + a.NetworkMapCache.UpdatePeer(peer) +} + +func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) { + a.initNetworkMapBuilder(validatedPeers) +} diff --git a/management/server/types/networkmap_golden_test.go b/management/server/types/networkmap_golden_test.go new file mode 100644 index 000000000..d85aaabb2 --- /dev/null +++ b/management/server/types/networkmap_golden_test.go @@ -0,0 +1,1069 @@ +package types_test + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/netip" + "os" + "path/filepath" + "slices" + "sort" + "testing" + "time" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes 
"github.com/netbirdio/netbird/management/server/networks/routers/types" + networkTypes "github.com/netbirdio/netbird/management/server/networks/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/route" +) + +// update flag is used to update the golden file. +// example: go test ./... -v -update +// var update = flag.Bool("update", false, "update golden files") + +const ( + numPeers = 100 + devGroupID = "group-dev" + opsGroupID = "group-ops" + allGroupID = "group-all" + routeID = route.ID("route-main") + routeHA1ID = route.ID("route-ha-1") + routeHA2ID = route.ID("route-ha-2") + policyIDDevOps = "policy-dev-ops" + policyIDAll = "policy-all" + policyIDPosture = "policy-posture" + policyIDDrop = "policy-drop" + postureCheckID = "posture-check-ver" + networkResourceID = "res-database" + networkID = "net-database" + networkRouterID = "router-database" + nameserverGroupID = "ns-group-main" + testingPeerID = "peer-60" // A peer from the "dev" group, should receive the most detailed map. + expiredPeerID = "peer-98" // This peer will be online but with an expired session. + offlinePeerID = "peer-99" // This peer will be completely offline. + routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network. 
+ testAccountID = "account-golden-test" +) + +func TestGetPeerNetworkMap_Golden(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath 
:= filepath.Join("testdata", "networkmap_golden_new.json") + + t.Log("Update golden file...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file") +} + +func BenchmarkGetPeerNetworkMap(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + b.ResetTimer() + b.Run("old builder", func(b *testing.B) { + for range b.N { + for _, peerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + b.ResetTimer() + b.Run("new builder", func(b *testing.B) { + for range b.N { + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + for _, peerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: 
nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json") + + t.Log("Update golden file with new peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newPeerID := "peer-new-101" + newPeerIP := net.IP{100, 64, 1, 1} + newPeer := &nbpeer.Peer{ + ID: 
newPeerID, + IP: newPeerIP, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newPeerID] = newPeer + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = append(devGroup.Peers, newPeerID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newPeerID) + } + + validatedPeersMap[newPeerID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newPeerID) + require.NoError(t, err, "error adding peer to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json") + t.Log("Update golden file with OnPeerAdded...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = 
append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newPeerID := "peer-new-101" + newPeer := &nbpeer.Peer{ + ID: newPeerID, + IP: net.IP{100, 64, 1, 1}, + Key: fmt.Sprintf("key-%s", newPeerID), + DNSLabel: "peernew101", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + } + + account.Peers[newPeerID] = newPeer + account.Groups[devGroupID].Peers = append(account.Groups[devGroupID].Peers, newPeerID) + account.Groups[allGroupID].Peers = append(account.Groups[allGroupID].Peers, newPeerID) + validatedPeersMap[newPeerID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: 
func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json") + + t.Log("Update golden file with new router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + 
validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerAddedIncremental(newRouterID) + require.NoError(t, err, "error adding router to cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", 
"networkmap_golden_new_with_onpeeradded_router.json") + + t.Log("Update golden file with OnPeerAdded router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + newRouterID := "peer-new-router-102" + newRouterIP := net.IP{100, 64, 1, 2} + newRouter := &nbpeer.Peer{ + ID: newRouterID, + IP: newRouterIP, + Key: fmt.Sprintf("key-%s", newRouterID), + DNSLabel: "newrouter102", + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", + Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"}, + LastLogin: func() *time.Time { t := time.Now(); return &t }(), + } + + account.Peers[newRouterID] = newRouter + + if opsGroup, exists := account.Groups[opsGroupID]; exists { + opsGroup.Peers = append(opsGroup.Peers, newRouterID) + } + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = append(allGroup.Peers, newRouterID) + } + + newRoute := &route.Route{ + ID: route.ID("route-new-router"), + Network: netip.MustParsePrefix("172.16.0.0/24"), + Peer: newRouter.Key, + PeerID: newRouterID, + Description: "Route from new router", + Enabled: true, + PeerGroups: []string{opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + 
AccessControlGroups: []string{devGroupID}, + AccountID: account.Id, + } + account.Routes[newRoute.ID] = newRoute + + validatedPeersMap[newRouterID] = struct{}{} + + b.ResetTimer() + b.Run("old builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after add", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerAddedIncremental(newRouterID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedPeerID := "peer-25" // peer from devs group + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, 
err, "error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedPeerID := "peer-25" // devs group peer + + delete(account.Peers, deletedPeerID) + + if devGroup, exists := account.Groups[devGroupID]; exists { + devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + if allGroup, exists := account.Groups[allGroupID]; exists { + allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool { + return id == deletedPeerID + }) + } + + delete(validatedPeersMap, deletedPeerID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedPeerID) + require.NoError(t, err, "error deleting peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, "error marshaling network map to JSON") + + goldenFilePath := 
filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json") + t.Log("Update golden file with OnPeerDeleted...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + resourcePolicies := account.GetResourcePoliciesMap() + routers := account.GetResourceRoutersMap() + + networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil) + + normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err, 
"error marshaling network map to JSON") + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json") + + t.Log("Update golden file with deleted peer...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err, "error reading golden file") + + require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file") +} + +func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) { + account := createTestAccountWithEntities() + + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + if peerID == offlinePeerID { + continue + } + validatedPeersMap[peerID] = struct{}{} + } + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + deletedRouterID := "peer-75" // router peer + + var affectedRoute *route.Route + for _, r := range account.Routes { + if r.PeerID == deletedRouterID { + affectedRoute = r + break + } + } + require.NotNil(t, affectedRoute, "Router peer should have a route") + + for _, group := range account.Groups { + group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool { + return id == deletedRouterID + }) + } + for routeID, r := range account.Routes { + if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID { + delete(account.Routes, routeID) + } + } + delete(account.Peers, deletedRouterID) + delete(validatedPeersMap, deletedRouterID) + + if account.Network != nil { + account.Network.Serial++ + } + + err := builder.OnPeerDeleted(deletedRouterID) + require.NoError(t, err, "error deleting routing peer from cache") + + networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil) + + 
normalizeAndSortNetworkMap(networkMap) + + jsonData, err := json.MarshalIndent(networkMap, "", " ") + require.NoError(t, err) + + goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json") + + t.Log("Update golden file with deleted router...") + err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755) + require.NoError(t, err) + err = os.WriteFile(goldenFilePath, jsonData, 0644) + require.NoError(t, err) + + expectedJSON, err := os.ReadFile(goldenFilePath) + require.NoError(t, err) + + require.JSONEq(t, string(expectedJSON), string(jsonData), + "network map after deleting router does not match golden file") +} + +func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) { + account := createTestAccountWithEntities() + ctx := context.Background() + validatedPeersMap := make(map[string]struct{}) + var peerIDs []string + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + validatedPeersMap[peerID] = struct{}{} + peerIDs = append(peerIDs, peerID) + } + + deletedPeerID := "peer-25" + + delete(account.Peers, deletedPeerID) + account.Groups[devGroupID].Peers = slices.DeleteFunc(account.Groups[devGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + account.Groups[allGroupID].Peers = slices.DeleteFunc(account.Groups[allGroupID].Peers, func(id string) bool { + return id == deletedPeerID + }) + delete(validatedPeersMap, deletedPeerID) + + builder := types.NewNetworkMapBuilder(account, validatedPeersMap) + + b.ResetTimer() + b.Run("old builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + for _, testingPeerID := range peerIDs { + _ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil) + } + } + }) + + b.ResetTimer() + b.Run("new builder after delete", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = builder.OnPeerDeleted(deletedPeerID) + for _, testingPeerID := range peerIDs { + _ = builder.GetPeerNetworkMap(ctx, testingPeerID, 
dns.CustomZone{}, validatedPeersMap, nil) + } + } + }) +} + +func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) { + for _, peer := range networkMap.Peers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + for _, peer := range networkMap.OfflinePeers { + if peer.Status != nil { + peer.Status.LastSeen = time.Time{} + } + peer.LastLogin = &time.Time{} + } + + sort.Slice(networkMap.Peers, func(i, j int) bool { return networkMap.Peers[i].ID < networkMap.Peers[j].ID }) + sort.Slice(networkMap.OfflinePeers, func(i, j int) bool { return networkMap.OfflinePeers[i].ID < networkMap.OfflinePeers[j].ID }) + sort.Slice(networkMap.Routes, func(i, j int) bool { return networkMap.Routes[i].ID < networkMap.Routes[j].ID }) + + sort.Slice(networkMap.FirewallRules, func(i, j int) bool { + r1, r2 := networkMap.FirewallRules[i], networkMap.FirewallRules[j] + if r1.PeerIP != r2.PeerIP { + return r1.PeerIP < r2.PeerIP + } + if r1.Protocol != r2.Protocol { + return r1.Protocol < r2.Protocol + } + if r1.Direction != r2.Direction { + return r1.Direction < r2.Direction + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + return r1.Port < r2.Port + }) + + sort.Slice(networkMap.RoutesFirewallRules, func(i, j int) bool { + r1, r2 := networkMap.RoutesFirewallRules[i], networkMap.RoutesFirewallRules[j] + if r1.RouteID != r2.RouteID { + return r1.RouteID < r2.RouteID + } + if r1.Action != r2.Action { + return r1.Action < r2.Action + } + if r1.Destination != r2.Destination { + return r1.Destination < r2.Destination + } + if len(r1.SourceRanges) > 0 && len(r2.SourceRanges) > 0 { + if r1.SourceRanges[0] != r2.SourceRanges[0] { + return r1.SourceRanges[0] < r2.SourceRanges[0] + } + } + return r1.Port < r2.Port + }) + + for _, ranges := range networkMap.RoutesFirewallRules { + sort.Slice(ranges.SourceRanges, func(i, j int) bool { + return ranges.SourceRanges[i] < ranges.SourceRanges[j] + }) + } +} + +func 
createTestAccountWithEntities() *types.Account { + peers := make(map[string]*nbpeer.Peer) + devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{} + + for i := range numPeers { + peerID := fmt.Sprintf("peer-%d", i) + ip := net.IP{100, 64, 0, byte(i + 1)} + wtVersion := "0.25.0" + if i%2 == 0 { + wtVersion = "0.40.0" + } + + p := &nbpeer.Peer{ + ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1), + Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()}, + UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"}, + } + + if peerID == expiredPeerID { + p.LoginExpirationEnabled = true + pastTimestamp := time.Now().Add(-2 * time.Hour) + p.LastLogin = &pastTimestamp + } + + peers[peerID] = p + allGroupPeers = append(allGroupPeers, peerID) + if i < numPeers/2 { + devGroupPeers = append(devGroupPeers, peerID) + } else { + opsGroupPeers = append(opsGroupPeers, peerID) + } + + } + + groups := map[string]*types.Group{ + allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers}, + devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers}, + opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers}, + } + + policies := []*types.Policy{ + { + ID: policyIDAll, Name: "Default-Allow", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDAll, Name: "Allow All", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{allGroupID}, Destinations: []string{allGroupID}, + }}, + }, + { + ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolTCP, Bidirectional: false, + PortRanges: []types.RulePortRange{{Start: 8080, End: 8090}}, + Sources: []string{devGroupID}, Destinations: 
[]string{opsGroupID}, + }}, + }, + { + ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true, + Rules: []*types.PolicyRule{{ + ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: types.PolicyTrafficActionDrop, + Protocol: types.PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true, + Sources: []string{devGroupID}, Destinations: []string{opsGroupID}, + }}, + }, + { + ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true, + SourcePostureChecks: []string{postureCheckID}, + Rules: []*types.PolicyRule{{ + ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: types.PolicyTrafficActionAccept, + Protocol: types.PolicyRuleProtocolALL, Bidirectional: true, + Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID}, + }}, + }, + } + + routes := map[route.ID]*route.Route{ + routeID: { + ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"), + Peer: peers["peer-75"].Key, + PeerID: "peer-75", + Description: "Route to internal resource", Enabled: true, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{devGroupID}, + }, + routeHA1ID: { + ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-80"].Key, + PeerID: "peer-80", + Description: "HA Route 1", Enabled: true, Metric: 1000, + PeerGroups: []string{allGroupID}, + Groups: []string{allGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + routeHA2ID: { + ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"), + Peer: peers["peer-90"].Key, + PeerID: "peer-90", + Description: "HA Route 2", Enabled: true, Metric: 900, + PeerGroups: []string{devGroupID, opsGroupID}, + Groups: []string{devGroupID, opsGroupID}, + AccessControlGroups: []string{allGroupID}, + }, + } + + account := &types.Account{ + Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes, + Network: &types.Network{ + Identifier: 
"net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1, + }, + DNSSettings: types.DNSSettings{DisabledManagementGroups: []string{opsGroupID}}, + NameServerGroups: map[string]*dns.NameServerGroup{ + nameserverGroupID: { + ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID}, + NameServers: []dns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: dns.UDPNameServerType, Port: 53}}, + }, + }, + PostureChecks: []*posture.Checks{ + {ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{ + NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"}, + }}, + }, + NetworkResources: []*resourceTypes.NetworkResource{ + {ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"}, + }, + Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}}, + NetworkRouters: []*routerTypes.NetworkRouter{ + {ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID}, + }, + Settings: &types.Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour}, + } + + for _, p := range account.Policies { + p.AccountID = account.Id + } + for _, r := range account.Routes { + r.AccountID = account.Id + } + + return account +} diff --git a/management/server/types/networkmapbuilder.go b/management/server/types/networkmapbuilder.go new file mode 100644 index 000000000..58f1bfa30 --- /dev/null +++ b/management/server/types/networkmapbuilder.go @@ -0,0 +1,1932 @@ +package types + +import ( + "context" + "fmt" + "slices" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" + "golang.org/x/exp/maps" + + nbdns "github.com/netbirdio/netbird/dns" + resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" + routerTypes 
"github.com/netbirdio/netbird/management/server/networks/routers/types" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/netbirdio/netbird/route" +) + +const ( + allPeers = "0.0.0.0" + fw = "fw:" + rfw = "route-fw:" + nr = "network-resource-" +) + +type NetworkMapCache struct { + globalRoutes map[route.ID]*route.Route + globalRules map[string]*FirewallRule //ruleId + globalRouteRules map[string]*RouteFirewallRule //ruleId + globalPeers map[string]*nbpeer.Peer + + groupToPeers map[string][]string + peerToGroups map[string][]string + policyToRules map[string][]*PolicyRule //policyId + groupToPolicies map[string][]*Policy + groupToRoutes map[string][]*route.Route + peerToRoutes map[string][]*route.Route + + peerACLs map[string]*PeerACLView + peerRoutes map[string]*PeerRoutesView + peerDNS map[string]*nbdns.Config + + resourceRouters map[string]map[string]*routerTypes.NetworkRouter + resourcePolicies map[string][]*Policy + + globalResources map[string]*resourceTypes.NetworkResource // resourceId + + acgToRoutes map[string]map[route.ID]*RouteOwnerInfo // routeID -> owner info + noACGRoutes map[route.ID]*RouteOwnerInfo + + mu sync.RWMutex +} + +type RouteOwnerInfo struct { + PeerID string + RouteID route.ID +} + +type PeerACLView struct { + ConnectedPeerIDs []string + FirewallRuleIDs []string +} + +type PeerRoutesView struct { + OwnRouteIDs []route.ID + NetworkResourceIDs []route.ID + InheritedRouteIDs []route.ID + RouteFirewallRuleIDs []string +} + +type NetworkMapBuilder struct { + account atomic.Pointer[Account] + cache *NetworkMapCache + validatedPeers map[string]struct{} +} + +func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder { + builder := &NetworkMapBuilder{ + cache: &NetworkMapCache{ + globalRoutes: make(map[route.ID]*route.Route), + globalRules: make(map[string]*FirewallRule), + globalRouteRules: 
make(map[string]*RouteFirewallRule), + globalPeers: make(map[string]*nbpeer.Peer), + groupToPeers: make(map[string][]string), + peerToGroups: make(map[string][]string), + policyToRules: make(map[string][]*PolicyRule), + groupToPolicies: make(map[string][]*Policy), + groupToRoutes: make(map[string][]*route.Route), + peerToRoutes: make(map[string][]*route.Route), + peerACLs: make(map[string]*PeerACLView), + peerRoutes: make(map[string]*PeerRoutesView), + peerDNS: make(map[string]*nbdns.Config), + globalResources: make(map[string]*resourceTypes.NetworkResource), + acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo), + noACGRoutes: make(map[route.ID]*RouteOwnerInfo), + }, + validatedPeers: make(map[string]struct{}), + } + builder.account.Store(account) + maps.Copy(builder.validatedPeers, validatedPeers) + + builder.initialBuild(account) + + return builder +} + +func (b *NetworkMapBuilder) initialBuild(account *Account) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + start := time.Now() + + b.buildGlobalIndexes(account) + + resourceRouters := account.GetResourceRoutersMap() + resourcePolicies := account.GetResourcePoliciesMap() + b.cache.resourceRouters = resourceRouters + b.cache.resourcePolicies = resourcePolicies + + for peerID := range account.Peers { + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + } + + log.Debugf("NetworkMapBuilder: Initial build completed in %v for account %s", time.Since(start), account.Id) +} + +func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) { + clear(b.cache.globalPeers) + clear(b.cache.groupToPeers) + clear(b.cache.peerToGroups) + clear(b.cache.policyToRules) + clear(b.cache.groupToPolicies) + clear(b.cache.globalRoutes) + clear(b.cache.globalRules) + clear(b.cache.globalRouteRules) + clear(b.cache.globalResources) + clear(b.cache.groupToRoutes) + clear(b.cache.peerToRoutes) + clear(b.cache.acgToRoutes) + clear(b.cache.noACGRoutes) + + 
maps.Copy(b.cache.globalPeers, account.Peers) + + for groupID, group := range account.Groups { + peersCopy := make([]string, len(group.Peers)) + copy(peersCopy, group.Peers) + b.cache.groupToPeers[groupID] = peersCopy + + for _, peerID := range group.Peers { + b.cache.peerToGroups[peerID] = append(b.cache.peerToGroups[peerID], groupID) + } + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + b.cache.policyToRules[policy.ID] = policy.Rules + + affectedGroups := make(map[string]struct{}) + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + for _, groupID := range rule.Sources { + affectedGroups[groupID] = struct{}{} + } + for _, groupID := range rule.Destinations { + affectedGroups[groupID] = struct{}{} + } + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + groupId := rule.SourceResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.SourceResource.ID] = append(b.cache.peerToGroups[rule.SourceResource.ID], groupId) + } + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + groupId := rule.DestinationResource.ID + affectedGroups[groupId] = struct{}{} + b.cache.peerToGroups[rule.DestinationResource.ID] = append(b.cache.peerToGroups[rule.DestinationResource.ID], groupId) + } + } + + for groupID := range affectedGroups { + b.cache.groupToPolicies[groupID] = append(b.cache.groupToPolicies[groupID], policy) + } + } + + for _, resource := range account.NetworkResources { + if !resource.Enabled { + continue + } + b.cache.globalResources[resource.ID] = resource + } + + for _, r := range account.Routes { + if !r.Enabled { + continue + } + for _, groupID := range r.PeerGroups { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } +} + 
+func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) { + peer := account.GetPeer(peerID) + if peer == nil { + return + } + + allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers) + + isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer) + + var emptyExpiredPeers []*nbpeer.Peer + finalAllPeers := b.addNetworksRoutingPeers( + networkResourcesRoutes, + peer, + allPotentialPeers, + emptyExpiredPeers, + isRouter, + sourcePeers, + ) + + view := &PeerACLView{ + ConnectedPeerIDs: make([]string, 0, len(finalAllPeers)), + FirewallRuleIDs: make([]string, 0, len(firewallRules)), + } + + for _, p := range finalAllPeers { + view.ConnectedPeerIDs = append(view.ConnectedPeerIDs, p.ID) + } + + for _, rule := range firewallRules { + ruleID := b.generateFirewallRuleID(rule) + view.FirewallRuleIDs = append(view.FirewallRuleIDs, ruleID) + b.cache.globalRules[ruleID] = rule + } + + b.cache.peerACLs[peerID] = view +} + +func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer, + validatedPeersMap map[string]struct{}, +) ([]*nbpeer.Peer, []*FirewallRule) { + ctx := context.Background() + + peerID := peer.ID + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + rulesExists := make(map[string]struct{}) + peersExists := make(map[string]struct{}) + fwRules := make([]*FirewallRule, 0) + peers := make([]*nbpeer.Peer, 0) + + for _, group := range peerGroups { + policies := b.cache.groupToPolicies[group] + for _, policy := range policies { + if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid { + continue + } + rules := b.cache.policyToRules[policy.ID] + for _, rule := range rules { + var sourcePeers, destinationPeers []*nbpeer.Peer + var peerInSources, 
peerInDestinations bool + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peerInSources = rule.SourceResource.ID == peerID + } else { + peerInSources = b.isPeerInGroupscached(rule.Sources, peerGroupsMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peerInDestinations = rule.DestinationResource.ID == peerID + } else { + peerInDestinations = b.isPeerInGroupscached(rule.Destinations, peerGroupsMap) + } + + if !peerInSources && !peerInDestinations { + continue + } + + if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" { + peer := account.GetPeer(rule.SourceResource.ID) + if peer != nil { + sourcePeers = []*nbpeer.Peer{peer} + } + } else { + sourcePeers = b.getPeersFromGroupscached(account, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) + } + + if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" { + peer := account.GetPeer(rule.DestinationResource.ID) + if peer != nil { + destinationPeers = []*nbpeer.Peer{peer} + } + } else { + destinationPeers = b.getPeersFromGroupscached(account, rule.Destinations, peerID, nil, validatedPeersMap) + } + + if rule.Bidirectional { + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + + if peerInSources { + b.generateResourcescached( + account, rule, destinationPeers, FirewallRuleDirectionOUT, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + + if peerInDestinations { + b.generateResourcescached( + account, rule, sourcePeers, FirewallRuleDirectionIN, + peer, &peers, &fwRules, peersExists, rulesExists, + ) + } + } + } + } + + return peers, fwRules +} + +func 
(b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool { + for _, groupID := range groupIDs { + if _, exists := peerGroupsMap[groupID]; exists { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs []string, + excludePeerID string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, +) []*nbpeer.Peer { + ctx := context.Background() + uniquePeers := make(map[string]*nbpeer.Peer) + + for _, groupID := range groupIDs { + peerIDs := b.cache.groupToPeers[groupID] + for _, peerID := range peerIDs { + if peerID == excludePeerID { + continue + } + + if _, ok := validatedPeersMap[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + if len(postureChecksIDs) > 0 { + if !account.validatePostureChecksOnPeer(ctx, postureChecksIDs, peerID) { + continue + } + } + + uniquePeers[peerID] = peer + } + } + + result := make([]*nbpeer.Peer, 0, len(uniquePeers)) + for _, peer := range uniquePeers { + result = append(result, peer) + } + + return result +} + +func (b *NetworkMapBuilder) generateResourcescached( + account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer, + peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{}, +) { + isAll := false + if allGroup, err := account.GetGroupAll(); err == nil { + isAll = (len(allGroup.Peers) - 1) == len(groupPeers) + } + + for _, peer := range groupPeers { + if peer == nil { + continue + } + if _, ok := peersExists[peer.ID]; !ok { + *peers = append(*peers, peer) + peersExists[peer.ID] = struct{}{} + } + + fr := FirewallRule{ + PolicyID: rule.ID, + PeerIP: peer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + + if isAll { + fr.PeerIP = allPeers + } + + var s strings.Builder + s.WriteString(rule.ID) + 
s.WriteString(fr.PeerIP)
		s.WriteString(strconv.Itoa(direction))
		s.WriteString(fr.Protocol)
		s.WriteString(fr.Action)
		s.WriteString(strings.Join(rule.Ports, ","))

		// Deduplicate identical rules produced by overlapping groups/policies.
		ruleID := s.String()
		if _, ok := rulesExists[ruleID]; ok {
			continue
		}
		rulesExists[ruleID] = struct{}{}

		// No ports configured: emit the rule as-is; otherwise expand it into
		// one rule per port / port range.
		if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
			*rules = append(*rules, &fr)
			continue
		}

		*rules = append(*rules, expandPortsAndRanges(fr, rule, targetPeer)...)
	}
}

// getNetworkResourcesForPeer inspects every enabled network resource and
// returns: whether the peer acts as a router for any resource, the resource
// routes visible to the peer (own routes when routing, plus routes of other
// routers for resources it can access as a client), and the set of source
// peers that are allowed to reach resources this peer routes.
func (b *NetworkMapBuilder) getNetworkResourcesForPeer(account *Account, peer *nbpeer.Peer) (bool, []*route.Route, map[string]struct{}) {
	ctx := context.Background()
	peerID := peer.ID

	var isRoutingPeer bool
	var routes []*route.Route
	allSourcePeers := make(map[string]struct{})

	peerGroups := b.cache.peerToGroups[peerID]
	peerGroupsMap := make(map[string]struct{}, len(peerGroups))
	for _, groupID := range peerGroups {
		peerGroupsMap[groupID] = struct{}{}
	}

	for _, resource := range b.cache.globalResources {
		networkRoutingPeers := b.cache.resourceRouters[resource.NetworkID]
		resourcePolicies := b.cache.resourcePolicies[resource.ID]
		// A resource without any access policy is reachable by no one.
		if len(resourcePolicies) == 0 {
			continue
		}

		isRouterForThisResource := false

		if networkRoutingPeers != nil {
			if router, ok := networkRoutingPeers[peerID]; ok && router.Enabled {
				isRoutingPeer = true
				isRouterForThisResource = true
				if rt := b.createNetworkResourceRoutes(resource, peerID, router, resourcePolicies); rt != nil {
					routes = append(routes, rt)
				}
			}
		}

		// A router for this resource never also consumes it as a client;
		// otherwise check whether any policy grants this peer client access.
		hasAccessAsClient := false
		if !isRouterForThisResource {
			for _, policy := range resourcePolicies {
				if b.isPeerInGroupscached(policy.SourceGroups(), peerGroupsMap) {
					if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) {
						hasAccessAsClient = true
						break
					}
				}
			}
		}

		// Clients receive the routes of every enabled router for the resource.
		if hasAccessAsClient && networkRoutingPeers != nil {
			for routerPeerID, router := range networkRoutingPeers {
				if router.Enabled {
					if rt := b.createNetworkResourceRoutes(resource, routerPeerID, router, resourcePolicies); rt != nil {
						routes = append(routes, rt)
					}
				}
			}
		}

		// As a router, collect all source peers that the resource policies
		// allow, so they can be added to this peer's connection set.
		if isRouterForThisResource {
			for _, policy := range resourcePolicies {
				// Guard against a policy with no rules: indexing Rules[0]
				// unconditionally would panic.
				if len(policy.Rules) == 0 {
					continue
				}
				var peersWithAccess []*nbpeer.Peer
				firstRule := policy.Rules[0]
				if firstRule.SourceResource.Type == ResourceTypePeer && firstRule.SourceResource.ID != "" {
					peersWithAccess = []*nbpeer.Peer{peer}
				} else {
					peersWithAccess = b.getPeersFromGroupscached(account, policy.SourceGroups(), "", policy.SourcePostureChecks, b.validatedPeers)
				}
				for _, p := range peersWithAccess {
					allSourcePeers[p.ID] = struct{}{}
				}
			}
		}
	}

	return isRoutingPeer, routes, allSourcePeers
}

// createNetworkResourceRoutes converts a network resource into a route owned
// by routerPeerID. Returns nil when the resource has no access policies or the
// router peer is unknown to the cache.
func (b *NetworkMapBuilder) createNetworkResourceRoutes(
	resource *resourceTypes.NetworkResource, routerPeerID string,
	router *routerTypes.NetworkRouter, resourcePolicies []*Policy,
) *route.Route {
	if len(resourcePolicies) > 0 {
		if peer := b.cache.globalPeers[routerPeerID]; peer != nil {
			return resource.ToRoute(peer, router)
		}
	}
	return nil
}

// addNetworksRoutingPeers extends peersToConnect with peers that must be
// reachable because of network-resource routing: the routing peers owning the
// resource routes and, when this peer is itself a router, the allowed source
// peers. Peers already connected, expired, or equal to the target peer are
// excluded.
func (b *NetworkMapBuilder) addNetworksRoutingPeers(
	networkResourcesRoutes []*route.Route, peer *nbpeer.Peer, peersToConnect []*nbpeer.Peer,
	expiredPeers []*nbpeer.Peer, isRouter bool, sourcePeers map[string]struct{},
) []*nbpeer.Peer {

	networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
	for _, r := range networkResourcesRoutes {
		networkRoutesPeers[r.PeerID] = struct{}{}
	}

	delete(sourcePeers, peer.ID)
	delete(networkRoutesPeers, peer.ID)

	for _, existingPeer := range peersToConnect {
		delete(sourcePeers, existingPeer.ID)
		delete(networkRoutesPeers, existingPeer.ID)
	}
	for _, expPeer := range expiredPeers {
		delete(sourcePeers, expPeer.ID)
		delete(networkRoutesPeers, expPeer.ID)
	}

	missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
	if isRouter {
		for p := range sourcePeers {
			missingPeers[p] = struct{}{}
		}
	}
	for p := range
networkRoutesPeers { + missingPeers[p] = struct{}{} + } + + for p := range missingPeers { + if missingPeer := b.cache.globalPeers[p]; missingPeer != nil { + peersToConnect = append(peersToConnect, missingPeer) + } + } + + return peersToConnect +} + +func (b *NetworkMapBuilder) buildPeerRoutesView(account *Account, peerID string) { + ctx := context.Background() + peer := account.GetPeer(peerID) + if peer == nil { + return + } + resourcePolicies := b.cache.resourcePolicies + + view := &PeerRoutesView{ + OwnRouteIDs: make([]route.ID, 0), + NetworkResourceIDs: make([]route.ID, 0), + RouteFirewallRuleIDs: make([]string, 0), + } + + enabledRoutes, disabledRoutes := b.getRoutingPeerRoutes(peerID) + for _, rt := range enabledRoutes { + if rt.PeerID != "" && rt.PeerID != peerID { + if b.cache.globalPeers[rt.PeerID] == nil { + continue + } + } + + view.OwnRouteIDs = append(view.OwnRouteIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + aclView := b.cache.peerACLs[peerID] + if aclView != nil { + peerRoutesMembership := make(LookupMap) + for _, r := range append(enabledRoutes, disabledRoutes...) 
{ + peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{} + } + + peerGroups := b.cache.peerToGroups[peerID] + peerGroupsMap := make(LookupMap) + for _, groupID := range peerGroups { + peerGroupsMap[groupID] = struct{}{} + } + + for _, aclPeerID := range aclView.ConnectedPeerIDs { + if aclPeerID == peerID { + continue + } + activeRoutes, _ := b.getRoutingPeerRoutes(aclPeerID) + groupFilteredRoutes := account.filterRoutesByGroups(activeRoutes, peerGroupsMap) + haFilteredRoutes := account.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership) + + for _, inheritedRoute := range haFilteredRoutes { + view.InheritedRouteIDs = append(view.InheritedRouteIDs, inheritedRoute.ID) + b.cache.globalRoutes[inheritedRoute.ID] = inheritedRoute + } + } + } + + _, networkResourcesRoutes, _ := b.getNetworkResourcesForPeer(account, peer) + + for _, rt := range networkResourcesRoutes { + view.NetworkResourceIDs = append(view.NetworkResourceIDs, rt.ID) + b.cache.globalRoutes[rt.ID] = rt + } + + allRoutes := slices.Concat(enabledRoutes, networkResourcesRoutes) + b.updateACGIndexForPeer(peerID, allRoutes) + + routeFirewallRules := b.getPeerRoutesFirewallRules(account, peerID, b.validatedPeers) + for _, rule := range routeFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + + if len(networkResourcesRoutes) > 0 { + networkResourceFirewallRules := account.GetPeerNetworkResourceFirewallRules(ctx, peer, b.validatedPeers, networkResourcesRoutes, resourcePolicies) + for _, rule := range networkResourceFirewallRules { + ruleID := b.generateRouteFirewallRuleID(rule) + view.RouteFirewallRuleIDs = append(view.RouteFirewallRuleIDs, ruleID) + b.cache.globalRouteRules[ruleID] = rule + } + } + + b.cache.peerRoutes[peerID] = view +} + +func (b *NetworkMapBuilder) updateACGIndexForPeer(peerID string, routes []*route.Route) { + for acg, routeMap := 
range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID { + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == peerID { + delete(b.cache.noACGRoutes, routeID) + } + } + + for _, rt := range routes { + if !rt.Enabled { + continue + } + + if len(rt.AccessControlGroups) == 0 { + b.cache.noACGRoutes[rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } else { + for _, acg := range rt.AccessControlGroups { + if b.cache.acgToRoutes[acg] == nil { + b.cache.acgToRoutes[acg] = make(map[route.ID]*RouteOwnerInfo) + } + + b.cache.acgToRoutes[acg][rt.ID] = &RouteOwnerInfo{ + PeerID: peerID, + RouteID: rt.ID, + } + } + } + } +} + +func (b *NetworkMapBuilder) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) { + peer := b.cache.globalPeers[peerID] + if peer == nil { + return enabledRoutes, disabledRoutes + } + + seenRoute := make(map[route.ID]struct{}) + + takeRoute := func(r *route.Route, id string) { + if _, ok := seenRoute[r.ID]; ok { + return + } + seenRoute[r.ID] = struct{}{} + + if r.Enabled { + // maybe here is some mess - here we store peer key (see comment below) + r.Peer = peer.Key + enabledRoutes = append(enabledRoutes, r) + return + } + disabledRoutes = append(disabledRoutes, r) + } + + peerGroups := b.cache.peerToGroups[peerID] + for _, groupID := range peerGroups { + groupRoutes := b.cache.groupToRoutes[groupID] + for _, r := range groupRoutes { + newPeerRoute := r.Copy() + // and here we store peer ID - this logic is taken from original account.getRoutingPeerRoutes + newPeerRoute.Peer = peerID + newPeerRoute.PeerGroups = nil + newPeerRoute.ID = route.ID(string(r.ID) + ":" + peerID) + takeRoute(newPeerRoute, peerID) + } + } + for _, r := range b.cache.peerToRoutes[peerID] { + takeRoute(r.Copy(), peerID) + } + return enabledRoutes, 
disabledRoutes +} + +func (b *NetworkMapBuilder) getPeerRoutesFirewallRules(account *Account, peerID string, validatedPeersMap map[string]struct{}) []*RouteFirewallRule { + routesFirewallRules := make([]*RouteFirewallRule, 0) + + enabledRoutes, _ := b.getRoutingPeerRoutes(peerID) + for _, route := range enabledRoutes { + if len(route.AccessControlGroups) == 0 { + defaultPermit := getDefaultPermit(route) + routesFirewallRules = append(routesFirewallRules, defaultPermit...) + continue + } + + distributionPeers := b.getDistributionGroupsPeers(route) + + for _, accessGroup := range route.AccessControlGroups { + policies := b.getAllRoutePoliciesFromGroups([]string{accessGroup}) + + rules := b.getRouteFirewallRules(peerID, policies, route, validatedPeersMap, distributionPeers, account) + routesFirewallRules = append(routesFirewallRules, rules...) + } + } + + return routesFirewallRules +} + +func (b *NetworkMapBuilder) getDistributionGroupsPeers(route *route.Route) map[string]struct{} { + distPeers := make(map[string]struct{}) + for _, id := range route.Groups { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + distPeers[pID] = struct{}{} + } + } + return distPeers +} + +func (b *NetworkMapBuilder) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy { + routePolicies := make(map[string]*Policy) + + for _, groupID := range accessControlGroups { + candidatePolicies := b.cache.groupToPolicies[groupID] + + for _, policy := range candidatePolicies { + if _, found := routePolicies[policy.ID]; found { + continue + } + policyRules := b.cache.policyToRules[policy.ID] + for _, rule := range policyRules { + if slices.Contains(rule.Destinations, groupID) { + routePolicies[policy.ID] = policy + break + } + } + } + } + + return maps.Values(routePolicies) +} + +func (b *NetworkMapBuilder) getRouteFirewallRules( + peerID string, policies []*Policy, route *route.Route, validatedPeersMap 
map[string]struct{}, + distributionPeers map[string]struct{}, account *Account, +) []*RouteFirewallRule { + ctx := context.Background() + var fwRules []*RouteFirewallRule + for _, policy := range policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + rulePeers := b.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers, validatedPeersMap, account) + + rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN) + fwRules = append(fwRules, rules...) + } + } + return fwRules +} + +func (b *NetworkMapBuilder) getRulePeers( + rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}, + validatedPeersMap map[string]struct{}, account *Account, +) []*nbpeer.Peer { + distPeersWithPolicy := make(map[string]struct{}) + + for _, id := range rule.Sources { + groupPeers := b.cache.groupToPeers[id] + if groupPeers == nil { + continue + } + + for _, pID := range groupPeers { + if pID == peerID { + continue + } + _, distPeer := distributionPeers[pID] + _, valid := validatedPeersMap[pID] + + if distPeer && valid && account.validatePostureChecksOnPeer(context.Background(), postureChecks, pID) { + distPeersWithPolicy[pID] = struct{}{} + } + } + } + + distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy)) + for pID := range distPeersWithPolicy { + peer := b.cache.globalPeers[pID] + if peer == nil { + continue + } + distributionGroupPeers = append(distributionGroupPeers, peer) + } + return distributionGroupPeers +} + +func (b *NetworkMapBuilder) buildPeerDNSView(account *Account, peerID string) { + peerGroups := b.cache.peerToGroups[peerID] + checkGroups := make(map[string]struct{}, len(peerGroups)) + for _, groupID := range peerGroups { + checkGroups[groupID] = struct{}{} + } + + dnsManagementStatus := b.getPeerDNSManagementStatus(account, checkGroups) + dnsConfig := &nbdns.Config{ + ServiceEnable: 
dnsManagementStatus, + } + + if dnsManagementStatus { + dnsConfig.NameServerGroups = b.getPeerNSGroups(account, peerID, checkGroups) + } + + b.cache.peerDNS[peerID] = dnsConfig +} + +func (b *NetworkMapBuilder) getPeerDNSManagementStatus(account *Account, checkGroups map[string]struct{}) bool { + + enabled := true + for _, groupID := range account.DNSSettings.DisabledManagementGroups { + _, found := checkGroups[groupID] + if found { + enabled = false + break + } + } + return enabled +} + +func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, checkGroups map[string]struct{}) []*nbdns.NameServerGroup { + var peerNSGroups []*nbdns.NameServerGroup + + for _, nsGroup := range account.NameServerGroups { + if !nsGroup.Enabled { + continue + } + for _, gID := range nsGroup.Groups { + _, found := checkGroups[gID] + if found { + peer := b.cache.globalPeers[peerID] + if !peerIsNameserver(peer, nsGroup) { + peerNSGroups = append(peerNSGroups, nsGroup.Copy()) + break + } + } + } + } + + return peerNSGroups +} + +func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) { + b.account.Store(account) +} + +func (b *NetworkMapBuilder) GetPeerNetworkMap( + ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone, + validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics, +) *NetworkMap { + start := time.Now() + account := b.account.Load() + + peer := account.GetPeer(peerID) + if peer == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + b.cache.mu.RLock() + defer b.cache.mu.RUnlock() + + aclView := b.cache.peerACLs[peerID] + routesView := b.cache.peerRoutes[peerID] + dnsConfig := b.cache.peerDNS[peerID] + + if aclView == nil || routesView == nil || dnsConfig == nil { + return &NetworkMap{Network: account.Network.Copy()} + } + + nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers) + + if metrics != nil { + objectCount := int64(len(nm.Peers) + 
len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules)) + metrics.CountNetworkMapObjects(objectCount) + metrics.CountGetPeerNetworkMapDuration(time.Since(start)) + + if objectCount > 5000 { + log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache", + account.Id, objectCount) + } + } + + return nm +} + +func (b *NetworkMapBuilder) assembleNetworkMap( + account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView, + dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{}, +) *NetworkMap { + + var peersToConnect []*nbpeer.Peer + var expiredPeers []*nbpeer.Peer + + for _, peerID := range aclView.ConnectedPeerIDs { + if _, ok := validatedPeers[peerID]; !ok { + continue + } + + peer := b.cache.globalPeers[peerID] + if peer == nil { + continue + } + + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if account.Settings.PeerLoginExpirationEnabled && expired { + expiredPeers = append(expiredPeers, peer) + } else { + peersToConnect = append(peersToConnect, peer) + } + } + + var routes []*route.Route + allRouteIDs := slices.Concat(routesView.OwnRouteIDs, routesView.NetworkResourceIDs, routesView.InheritedRouteIDs) + + for _, routeID := range allRouteIDs { + if route := b.cache.globalRoutes[routeID]; route != nil { + routes = append(routes, route) + } + } + + var firewallRules []*FirewallRule + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil { + firewallRules = append(firewallRules, rule) + } + } + + var routesFirewallRules []*RouteFirewallRule + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + routesFirewallRules = append(routesFirewallRules, rule) + } + } + + finalDNSConfig := *dnsConfig + if finalDNSConfig.ServiceEnable && customZone.Domain != "" { + var zones []nbdns.CustomZone + records := 
filterZoneRecordsForPeers(peer, customZone, peersToConnect, expiredPeers) + zones = append(zones, nbdns.CustomZone{ + Domain: customZone.Domain, + Records: records, + }) + finalDNSConfig.CustomZones = zones + } + + return &NetworkMap{ + Peers: peersToConnect, + Network: account.Network.Copy(), + Routes: routes, + DNSConfig: finalDNSConfig, + OfflinePeers: expiredPeers, + FirewallRules: firewallRules, + RoutesFirewallRules: routesFirewallRules, + } +} + +func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string { + var s strings.Builder + s.WriteString(fw) + s.WriteString(rule.PolicyID) + s.WriteRune(':') + s.WriteString(rule.PeerIP) + s.WriteRune(':') + s.WriteString(strconv.Itoa(rule.Direction)) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(rule.Port) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.PortRange.Start))) + s.WriteRune('-') + s.WriteString(strconv.Itoa(int(rule.PortRange.End))) + return s.String() +} + +func (b *NetworkMapBuilder) generateRouteFirewallRuleID(rule *RouteFirewallRule) string { + var s strings.Builder + s.WriteString(rfw) + s.WriteString(string(rule.RouteID)) + s.WriteRune(':') + s.WriteString(rule.Destination) + s.WriteRune(':') + s.WriteString(rule.Action) + s.WriteRune(':') + s.WriteString(strings.Join(rule.SourceRanges, ",")) + s.WriteRune(':') + s.WriteString(rule.Protocol) + s.WriteRune(':') + s.WriteString(strconv.Itoa(int(rule.Port))) + return s.String() +} + +func (b *NetworkMapBuilder) isPeerInGroups(groupIDs []string, peerGroups []string) bool { + for _, groupID := range groupIDs { + if slices.Contains(peerGroups, groupID) { + return true + } + } + return false +} + +func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool { + for _, r := range account.Routes { + if !r.Enabled { + continue + } + + if r.PeerID == peerID { + return true + } + + if peer := b.cache.globalPeers[peerID]; peer != 
nil { + if r.Peer == peer.Key && r.PeerID == "" { + return true + } + } + } + + routers := account.GetResourceRoutersMap() + for _, networkRouters := range routers { + if router, exists := networkRouters[peerID]; exists && router.Enabled { + return true + } + } + + return false +} + +type ViewDelta struct { + AddedPeerIDs []string + RemovedPeerIDs []string + AddedRuleIDs []string + RemovedRuleIDs []string +} + +func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error { + tt := time.Now() + account := b.account.Load() + peer := account.GetPeer(peerID) + if peer == nil { + return fmt.Errorf("peer %s not found in account", peerID) + } + + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String()) + + b.validatedPeers[peerID] = struct{}{} + + b.cache.globalPeers[peerID] = peer + + peerGroups := b.updateIndexesForNewPeer(account, peerID) + + b.buildPeerACLView(account, peerID) + b.buildPeerRoutesView(account, peerID) + b.buildPeerDNSView(account, peerID) + + log.Debugf("NetworkMapBuilder: Adding peer %s to cache, views took %s", peerID, time.Since(tt)) + + b.incrementalUpdateAffectedPeers(account, peerID, peerGroups) + + log.Debugf("NetworkMapBuilder: Added peer %s to cache, took %s", peerID, time.Since(tt)) + + return nil +} + +func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID string) []string { + peerGroups := make([]string, 0) + + for groupID, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + if !slices.Contains(b.cache.groupToPeers[groupID], peerID) { + b.cache.groupToPeers[groupID] = append(b.cache.groupToPeers[groupID], peerID) + } + peerGroups = append(peerGroups, groupID) + } + } + + b.cache.peerToGroups[peerID] = peerGroups + + for _, r := range account.Routes { + if !r.Enabled || b.cache.globalRoutes[r.ID] != nil { + continue + } + for _, groupID := range r.PeerGroups { + if 
!slices.Contains(b.cache.groupToRoutes[groupID], r) { + b.cache.groupToRoutes[groupID] = append(b.cache.groupToRoutes[groupID], r) + } + } + if r.Peer != "" { + if peer, ok := b.cache.globalPeers[r.Peer]; ok { + if !slices.Contains(b.cache.peerToRoutes[peer.ID], r) { + b.cache.peerToRoutes[peer.ID] = append(b.cache.peerToRoutes[peer.ID], r) + } + } + } + b.cache.globalRoutes[r.ID] = r + } + + return peerGroups +} + +func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) { + updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups) + + if b.isPeerRouter(account, newPeerID) { + affectedByRoutes := b.findPeersAffectedByNewRouter(account, newPeerID, peerGroups) + for affectedPeerID := range affectedByRoutes { + if affectedPeerID == newPeerID { + continue + } + if _, exists := updates[affectedPeerID]; !exists { + updates[affectedPeerID] = &PeerUpdateDelta{ + PeerID: affectedPeerID, + RebuildRoutesView: true, + } + } else { + updates[affectedPeerID].RebuildRoutesView = true + } + } + } + + for affectedPeerID, delta := range updates { + b.applyDeltaToPeer(account, affectedPeerID, delta) + } +} + +func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} { + affected := make(map[string]struct{}) + enabledRoutes, _ := b.getRoutingPeerRoutes(newRouterID) + + for _, route := range enabledRoutes { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + + for _, peerGroupID := range route.PeerGroups { + if peers := b.cache.groupToPeers[peerGroupID]; peers != nil { + for _, peerID := range peers { + if peerID != newRouterID { + affected[peerID] = struct{}{} + } + } + } + } + } + + for _, route := range account.Routes { + if !route.Enabled { + continue + } + + 
routerInPeerGroups := false + for _, peerGroupID := range route.PeerGroups { + if slices.Contains(routerGroups, peerGroupID) { + routerInPeerGroups = true + break + } + } + + if routerInPeerGroups { + for _, distGroupID := range route.Groups { + if peers := b.cache.groupToPeers[distGroupID]; peers != nil { + for _, peerID := range peers { + affected[peerID] = struct{}{} + } + } + } + } + } + + return affected +} + +func (b *NetworkMapBuilder) calculateIncrementalUpdates(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta { + updates := make(map[string]*PeerUpdateDelta) + ctx := context.Background() + + groupAllLn := 0 + if allGroup, err := account.GetGroupAll(); err == nil { + groupAllLn = len(allGroup.Peers) - 1 + } + + newPeer := b.cache.globalPeers[newPeerID] + if newPeer == nil { + return updates + } + + for _, policy := range account.Policies { + if !policy.Enabled { + continue + } + + for _, rule := range policy.Rules { + if !rule.Enabled { + continue + } + + peerInSources := b.isPeerInGroups(rule.Sources, peerGroups) + peerInDestinations := b.isPeerInGroups(rule.Destinations, peerGroups) + + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + + if rule.Bidirectional { + if peerInSources { + b.addUpdateForPeersInGroups(updates, rule.Destinations, newPeerID, rule, FirewallRuleDirectionOUT, groupAllLn) + } + if peerInDestinations { + b.addUpdateForPeersInGroups(updates, rule.Sources, newPeerID, rule, FirewallRuleDirectionIN, groupAllLn) + } + } + } + } + + b.calculateRouteFirewallUpdates(newPeerID, newPeer, peerGroups, updates) + + b.calculateNetworkResourceFirewallUpdates(ctx, account, newPeerID, newPeer, peerGroups, updates) + + b.calculateNewRouterNetworkResourceUpdates(ctx, account, newPeerID, updates) + 
+ return updates +} + +func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates( + ctx context.Context, account *Account, newPeerID string, + updates map[string]*PeerUpdateDelta, +) { + resourceRouters := b.cache.resourceRouters + + for networkID, routers := range resourceRouters { + router, isRouter := routers[newPeerID] + if !isRouter || !router.Enabled { + continue + } + + for _, resource := range b.cache.globalResources { + if resource.NetworkID != networkID { + continue + } + + policies := b.cache.resourcePolicies[resource.ID] + if len(policies) == 0 { + continue + } + + peersWithAccess := make(map[string]struct{}) + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + groupPeers := b.cache.groupToPeers[sourceGroup] + for _, peerID := range groupPeers { + if peerID == newPeerID { + continue + } + + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID) { + peersWithAccess[peerID] = struct{}{} + } + } + } + } + + for peerID := range peersWithAccess { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + } + updates[peerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } + } +} + +func (b *NetworkMapBuilder) calculateRouteFirewallUpdates( + newPeerID string, newPeer *nbpeer.Peer, + peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + processedPeerRoutes := make(map[string]map[route.ID]struct{}) + + for routeID, info := range b.cache.noACGRoutes { + if info.PeerID == newPeerID { + continue + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + + for _, acg := range 
peerGroups { + routeInfos := b.cache.acgToRoutes[acg] + if routeInfos == nil { + continue + } + + for routeID, info := range routeInfos { + if info.PeerID == newPeerID { + continue + } + + if processedRoutes, exists := processedPeerRoutes[info.PeerID]; exists { + if _, processed := processedRoutes[routeID]; processed { + continue + } + } + + b.addRouteFirewallUpdate(updates, info.PeerID, string(routeID), newPeer.IP.String()) + + if processedPeerRoutes[info.PeerID] == nil { + processedPeerRoutes[info.PeerID] = make(map[route.ID]struct{}) + } + processedPeerRoutes[info.PeerID][routeID] = struct{}{} + } + } +} + +func (b *NetworkMapBuilder) addRouteFirewallUpdate( + updates map[string]*PeerUpdateDelta, peerID string, + routeID string, sourceIP string, +) { + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + UpdateRouteFirewallRules: make([]*RouteFirewallRuleUpdate, 0), + } + updates[peerID] = delta + } + + for _, existing := range delta.UpdateRouteFirewallRules { + if existing.RuleID == routeID && existing.AddSourceIP == sourceIP { + return + } + } + + delta.UpdateRouteFirewallRules = append(delta.UpdateRouteFirewallRules, &RouteFirewallRuleUpdate{ + RuleID: routeID, + AddSourceIP: sourceIP, + }) +} + +func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates( + ctx context.Context, account *Account, newPeerID string, + newPeer *nbpeer.Peer, peerGroups []string, updates map[string]*PeerUpdateDelta, +) { + for _, resource := range b.cache.globalResources { + resourcePolicies := b.cache.resourcePolicies + resourceRouters := b.cache.resourceRouters + + policies := resourcePolicies[resource.ID] + peerHasAccess := false + + for _, policy := range policies { + if !policy.Enabled { + continue + } + + sourceGroups := policy.SourceGroups() + for _, sourceGroup := range sourceGroups { + if slices.Contains(peerGroups, sourceGroup) { + if account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, newPeerID) { + 
peerHasAccess = true + break + } + } + } + + if peerHasAccess { + break + } + } + + if !peerHasAccess { + continue + } + + networkRouters := resourceRouters[resource.NetworkID] + for routerPeerID, router := range networkRouters { + if !router.Enabled || routerPeerID == newPeerID { + continue + } + + delta := updates[routerPeerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: routerPeerID, + } + updates[routerPeerID] = delta + } + + if delta.AddConnectedPeer == "" { + delta.AddConnectedPeer = newPeerID + } + + delta.RebuildRoutesView = true + } + } +} + +type PeerUpdateDelta struct { + PeerID string + AddConnectedPeer string + AddFirewallRules []*FirewallRuleDelta + AddRoutes []route.ID + UpdateRouteFirewallRules []*RouteFirewallRuleUpdate + UpdateDNS bool + RebuildRoutesView bool +} +type FirewallRuleDelta struct { + Rule *FirewallRule + RuleID string + Direction int +} + +type RouteFirewallRuleUpdate struct { + RuleID string + AddSourceIP string +} + +func (b *NetworkMapBuilder) addUpdateForPeersInGroups( + updates map[string]*PeerUpdateDelta, groupIDs []string, newPeerID string, + rule *PolicyRule, direction int, allGroupLn int, +) { + for _, groupID := range groupIDs { + peers := b.cache.groupToPeers[groupID] + cnt := 0 + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + cnt++ + } + all := false + if allGroupLn > 0 && cnt == allGroupLn { + all = true + } + newPeer := b.cache.globalPeers[newPeerID] + fr := &FirewallRule{ + PolicyID: rule.ID, + PeerIP: newPeer.IP.String(), + Direction: direction, + Action: string(rule.Action), + Protocol: string(rule.Protocol), + } + for _, peerID := range peers { + if peerID == newPeerID { + continue + } + if _, ok := b.validatedPeers[peerID]; !ok { + continue + } + delta := updates[peerID] + if delta == nil { + delta = &PeerUpdateDelta{ + PeerID: peerID, + AddConnectedPeer: newPeerID, + AddFirewallRules: make([]*FirewallRuleDelta, 
0), + } + updates[peerID] = delta + } + + if all { + fr.PeerIP = allPeers + } + + if len(rule.Ports) > 0 || len(rule.PortRanges) > 0 { + expandedRules := expandPortsAndRanges(*fr, rule, b.cache.globalPeers[peerID]) + for _, expandedRule := range expandedRules { + ruleID := b.generateFirewallRuleID(expandedRule) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: expandedRule, + RuleID: ruleID, + Direction: direction, + }) + } + } else { + ruleID := b.generateFirewallRuleID(fr) + delta.AddFirewallRules = append(delta.AddFirewallRules, &FirewallRuleDelta{ + Rule: fr, + RuleID: ruleID, + Direction: direction, + }) + } + } + } +} + +func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) { + if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) { + aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) + } + + for _, ruleDelta := range delta.AddFirewallRules { + b.cache.globalRules[ruleDelta.RuleID] = ruleDelta.Rule + + if !slices.Contains(aclView.FirewallRuleIDs, ruleDelta.RuleID) { + aclView.FirewallRuleIDs = append(aclView.FirewallRuleIDs, ruleDelta.RuleID) + } + } + } + } + + if delta.RebuildRoutesView { + b.buildPeerRoutesView(account, peerID) + } else if len(delta.UpdateRouteFirewallRules) > 0 { + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + b.updateRouteFirewallRules(routesView, delta.UpdateRouteFirewallRules) + } + } + + if delta.UpdateDNS { + b.buildPeerDNSView(account, peerID) + } +} + +func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView, updates []*RouteFirewallRuleUpdate) { + for _, update := range updates { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + rule := b.cache.globalRouteRules[ruleID] + if rule == nil { 
+ continue + } + + if string(rule.RouteID) == update.RuleID { + sourceIP := update.AddSourceIP + + if strings.Contains(sourceIP, ":") { + sourceIP += "/128" // IPv6 + } else { + sourceIP += "/32" // IPv4 + } + + if !slices.Contains(rule.SourceRanges, sourceIP) { + rule.SourceRanges = append(rule.SourceRanges, sourceIP) + } + break + } + } + } +} + +func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + + account := b.account.Load() + + deletedPeer := b.cache.globalPeers[peerID] + if deletedPeer == nil { + return fmt.Errorf("peer %s not found in cache", peerID) + } + + deletedPeerKey := deletedPeer.Key + peerGroups := b.cache.peerToGroups[peerID] + peerIP := deletedPeer.IP.String() + + log.Debugf("NetworkMapBuilder: Deleting peer %s (IP: %s) from cache", peerID, peerIP) + + delete(b.validatedPeers, peerID) + + routesToDelete := []route.ID{} + + for routeID, r := range account.Routes { + if r.Peer != deletedPeerKey && r.PeerID != peerID { + continue + } + if len(r.PeerGroups) == 0 { + routesToDelete = append(routesToDelete, routeID) + continue + } + newPeerAssigned := false + for _, groupID := range r.PeerGroups { + candidatePeerIDs := b.cache.groupToPeers[groupID] + for _, candidatePeerID := range candidatePeerIDs { + if candidatePeerID == peerID { + continue + } + if candidatePeer := b.cache.globalPeers[candidatePeerID]; candidatePeer != nil { + r.Peer = candidatePeer.Key + r.PeerID = candidatePeerID + newPeerAssigned = true + break + } + } + if newPeerAssigned { + break + } + } + + if !newPeerAssigned { + routesToDelete = append(routesToDelete, routeID) + } + } + + for _, routeID := range routesToDelete { + delete(account.Routes, routeID) + } + + delete(b.cache.peerACLs, peerID) + delete(b.cache.peerRoutes, peerID) + delete(b.cache.peerDNS, peerID) + + delete(b.cache.globalPeers, peerID) + + for acg, routeMap := range b.cache.acgToRoutes { + for routeID, info := range routeMap { + if info.PeerID == peerID 
{ + delete(routeMap, routeID) + } + } + if len(routeMap) == 0 { + delete(b.cache.acgToRoutes, acg) + } + } + + for _, groupID := range peerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + b.cache.groupToPeers[groupID] = slices.DeleteFunc(peers, func(id string) bool { + return id == peerID + }) + } + } + delete(b.cache.peerToGroups, peerID) + + affectedPeers := make(map[string]struct{}) + + for _, r := range account.Routes { + for _, groupID := range r.Groups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + + for _, groupID := range r.PeerGroups { + if peers := b.cache.groupToPeers[groupID]; peers != nil { + for _, p := range peers { + affectedPeers[p] = struct{}{} + } + } + } + } + + for affectedPeerID := range affectedPeers { + if affectedPeerID == peerID { + continue + } + b.buildPeerRoutesView(account, affectedPeerID) + } + + peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP) + for affectedPeerID, updates := range peerDeletionUpdates { + b.applyDeletionUpdates(affectedPeerID, updates) + } + + b.cleanupUnusedRules() + + log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers)) + + return nil +} + +func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL( + deletedPeerID string, + peerIP string, +) map[string]*PeerDeletionUpdate { + + affected := make(map[string]*PeerDeletionUpdate) + + for peerID, aclView := range b.cache.peerACLs { + if peerID == deletedPeerID { + continue + } + + if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) { + continue + } + if affected[peerID] == nil { + affected[peerID] = &PeerDeletionUpdate{ + RemovePeerID: deletedPeerID, + PeerIP: peerIP, + } + } + + for _, ruleID := range aclView.FirewallRuleIDs { + if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP { + affected[peerID].RemoveFirewallRuleIDs = append( + 
affected[peerID].RemoveFirewallRuleIDs, + ruleID, + ) + } + } + } + + return affected +} + +type PeerDeletionUpdate struct { + RemovePeerID string + RemoveFirewallRuleIDs []string + RemoveRouteIDs []route.ID + RemoveFromSourceRanges bool + PeerIP string +} + +func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) { + if aclView := b.cache.peerACLs[peerID]; aclView != nil { + aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool { + return id == updates.RemovePeerID + }) + + if len(updates.RemoveFirewallRuleIDs) > 0 { + aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID) + }) + } + } + + if routesView := b.cache.peerRoutes[peerID]; routesView != nil { + if len(updates.RemoveRouteIDs) > 0 { + routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool { + return slices.Contains(updates.RemoveRouteIDs, routeID) + }) + } + + if updates.RemoveFromSourceRanges { + b.removeIPFromRouteFirewallRules(routesView, updates.PeerIP) + } + } +} + +func (b *NetworkMapBuilder) removeIPFromRouteFirewallRules(routesView *PeerRoutesView, peerIP string) { + sourceIPv4 := peerIP + "/32" + sourceIPv6 := peerIP + "/128" + + rulesToRemove := []string{} + + for _, ruleID := range routesView.RouteFirewallRuleIDs { + if rule := b.cache.globalRouteRules[ruleID]; rule != nil { + rule.SourceRanges = slices.DeleteFunc(rule.SourceRanges, func(source string) bool { + return source == sourceIPv4 || source == sourceIPv6 || source == peerIP + }) + + if len(rule.SourceRanges) == 0 { + rulesToRemove = append(rulesToRemove, ruleID) + } + } + } + + if len(rulesToRemove) > 0 { + routesView.RouteFirewallRuleIDs = slices.DeleteFunc(routesView.RouteFirewallRuleIDs, func(ruleID string) bool { + return slices.Contains(rulesToRemove, ruleID) + }) + } +} + +func (b *NetworkMapBuilder) 
cleanupUnusedRules() { + usedFirewallRules := make(map[string]struct{}) + usedRouteRules := make(map[string]struct{}) + usedRoutes := make(map[route.ID]struct{}) + + for _, aclView := range b.cache.peerACLs { + for _, ruleID := range aclView.FirewallRuleIDs { + usedFirewallRules[ruleID] = struct{}{} + } + } + + for _, routesView := range b.cache.peerRoutes { + for _, ruleID := range routesView.RouteFirewallRuleIDs { + usedRouteRules[ruleID] = struct{}{} + } + + for _, routeID := range routesView.OwnRouteIDs { + usedRoutes[routeID] = struct{}{} + } + for _, routeID := range routesView.NetworkResourceIDs { + usedRoutes[routeID] = struct{}{} + } + } + + for ruleID := range b.cache.globalRules { + if _, used := usedFirewallRules[ruleID]; !used { + delete(b.cache.globalRules, ruleID) + } + } + + for ruleID := range b.cache.globalRouteRules { + if _, used := usedRouteRules[ruleID]; !used { + delete(b.cache.globalRouteRules, ruleID) + } + } + + for routeID := range b.cache.globalRoutes { + if _, used := usedRoutes[routeID]; !used { + delete(b.cache.globalRoutes, routeID) + } + } +} + +func (b *NetworkMapBuilder) UpdatePeer(peer *nbpeer.Peer) { + b.cache.mu.Lock() + defer b.cache.mu.Unlock() + peerStored, ok := b.cache.globalPeers[peer.ID] + if !ok { + return + } + *peerStored = *peer +} diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index da12f1b70..adf64592a 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -7,16 +7,14 @@ import ( log "github.com/sirupsen/logrus" - "github.com/netbirdio/netbird/shared/management/proto" "github.com/netbirdio/netbird/management/server/telemetry" - "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/management/proto" ) const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse - NetworkMap *types.NetworkMap + Update *proto.SyncResponse } type PeersUpdateManager struct { diff --git 
a/management/server/user.go b/management/server/user.go index d40d33c6a..66bea314f 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -595,7 +595,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } // prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data. -func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() { +func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() { var eventsToStore []func() if oldUser.IsBlocked() != newUser.IsBlocked() { @@ -621,6 +621,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac }) } + addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err) + } + + for _, group := range addedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta) + }) + } + + removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err) + } + for _, group := range removedGroups { + meta := map[string]any{ + "group": group.Name, "group_id": group.ID, + "is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName, + } + 
eventsToStore = append(eventsToStore, func() { + am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta) + }) + } + return eventsToStore } @@ -667,9 +696,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact peersToExpire = userPeers } + var removedGroups, addedGroups []string if update.AutoGroups != nil && settings.GroupsPropagationEnabled { - removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups) - addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups) + removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups) + addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups) for _, peer := range userPeers { for _, groupID := range removedGroups { if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil { @@ -685,7 +715,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact } updateAccountPeers := len(userPeers) > 0 - userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole) + userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction) return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil } @@ -961,6 +991,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou peer.UserID, peer.ID, accountID, activity.PeerLoginExpired, peer.EventMeta(dnsDomain), ) + + if am.experimentalNetworkMap(accountID) { + am.updatePeerInNetworkMapCache(peer.AccountID, peer) + } } if len(peerIDs) != 0 { diff --git a/route/route.go b/route/route.go index 08a2d37dc..c724e7c7d 100644 --- a/route/route.go +++ b/route/route.go @@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any { func (r *Route) Copy() *Route { route := &Route{ ID: r.ID, + AccountID: r.AccountID, Description: 
r.Description, NetID: r.NetID, Network: r.Network, diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index 076f2532b..520a83e36 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -55,8 +55,7 @@ func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsE var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.ManagementComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil } diff --git a/shared/management/http/util/util.go b/shared/management/http/util/util.go index 3ae321023..0a29469da 100644 --- a/shared/management/http/util/util.go +++ b/shared/management/http/util/util.go @@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) { httpStatus = http.StatusUnauthorized case status.BadRequest: httpStatus = http.StatusBadRequest + case status.TooManyRequests: + httpStatus = http.StatusTooManyRequests default: } msg = strings.ToLower(err.Error()) diff --git a/shared/management/proto/management.proto b/shared/management/proto/management.proto index ad9454915..ccd92b870 100644 --- a/shared/management/proto/management.proto +++ b/shared/management/proto/management.proto @@ -422,7 +422,7 @@ message DNSConfig { bool ServiceEnable = 1; repeated NameServerGroup NameServerGroups = 2; repeated CustomZone CustomZones = 3; - int64 ForwarderPort = 4; + int64 ForwarderPort = 4 [deprecated = true]; } // CustomZone represents a dns.CustomZone diff --git a/shared/management/status/error.go b/shared/management/status/error.go index 1e914babb..09676847e 100644 --- a/shared/management/status/error.go +++ b/shared/management/status/error.go @@ -37,6 +37,9 @@ const ( // Unauthenticated indicates that user is not authenticated due to absence of valid credentials Unauthenticated Type = 10 + + // TooManyRequests indicates that the user has sent too 
many requests in a given amount of time (rate limiting) + TooManyRequests Type = 11 ) // Type is a type of the Error diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 31f3372c0..5368b57a2 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -60,8 +60,7 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo var err error conn, err = nbgrpc.CreateConnection(ctx, addr, tlsEnabled, wsproxy.SignalComponent) if err != nil { - log.Printf("createConnection error: %v", err) - return err + return fmt.Errorf("create connection: %w", err) } return nil }