Compare commits

...

54 Commits

Author SHA1 Message Date
pascal
bdf22d0d1f remove auto updates gorm default latest 2025-12-15 17:05:55 +01:00
pascal
0f1a714a9f fix potential panic on setting update 2025-12-15 16:59:25 +01:00
Zoltán Papp
7656b38cca Merge branch 'main' into feat/auto-upgrade 2025-11-12 12:18:24 +01:00
Zoltan Papp
c28275611b Fix agent reference (#4776) 2025-11-11 13:59:32 +01:00
Vlad
56f169eede [management] fix pg db deadlock after app panic (#4772) 2025-11-10 23:43:08 +01:00
Viktor Liu
07cf9d5895 [client] Create networkd.conf.d if it doesn't exist (#4764) 2025-11-08 10:54:37 +01:00
Pascal Fischer
7df49e249d [management ] remove timing logs (#4761) 2025-11-07 20:14:52 +01:00
Pascal Fischer
dbfc8a52c9 [management] remove GLOBAL when disabling foreign keys on mysql (#4615) 2025-11-07 16:03:14 +01:00
Vlad
98ddac07bf [management] remove toAll firewall rule (#4725) 2025-11-07 15:50:58 +01:00
Pascal Fischer
48475ddc05 [management] add pat rate limiting (#4741) 2025-11-07 15:50:18 +01:00
Vlad
6aa4ba7af4 [management] incremental network map builder (#4753) 2025-11-07 10:44:46 +01:00
dependabot[bot]
2e16c9914a [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29.
- [Release notes](https://github.com/containerd/containerd/releases)
- [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md)
- [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29)

---
updated-dependencies:
- dependency-name: github.com/containerd/containerd
  dependency-version: 1.7.29
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-06 19:01:44 +03:00
Pascal Fischer
5c29d395b2 [management] activity events on group updates (#4750) 2025-11-06 12:51:14 +01:00
Viktor Liu
229e0038ee [client] Add dns config to debug bundle (#4704) 2025-11-05 17:30:17 +01:00
Viktor Liu
75327d9519 [client] Add login_hint to oidc flows (#4724) 2025-11-05 17:00:20 +01:00
Viktor Liu
c92e6c1b5f [client] Block on all subsystems on shutdown (#4709) 2025-11-05 12:15:37 +01:00
Viktor Liu
641eb5140b [client] Allow INPUT traffic on the compat iptables filter table for nftables (#4742) 2025-11-04 21:56:53 +01:00
Viktor Liu
45c25dca84 [client] Clamp MSS on outbound traffic (#4735) 2025-11-04 17:18:51 +01:00
Viktor Liu
679c58ce47 [client] Set up networkd to ignore ip rules (#4730) 2025-11-04 17:06:35 +01:00
Pascal Fischer
719283c792 [management] update db connection lifecycle configuration (#4740) 2025-11-03 17:40:12 +01:00
Zoltán Papp
030ddae51e Merge branch 'main' into feat/auto-upgrade 2025-10-27 09:57:54 +01:00
Zoltán Papp
6eee52b56e Fix auto update success message check 2025-10-15 14:44:45 +02:00
Zoltán Papp
9313b49625 Fix Windows installer 2025-10-14 17:15:22 +02:00
Zoltán Papp
18f884f769 - fix nil pointer for context
- fix development version handling
- add log lines
2025-10-13 22:10:09 +02:00
Zoltán Papp
1354096c4d Fix windows build 2025-10-13 20:59:56 +02:00
Zoltán Papp
cd19f4d910 Code cleaning in updateState 2025-10-13 20:47:52 +02:00
Zoltán Papp
bab5cd4b41 Clean up temp dir 2025-10-13 20:38:49 +02:00
Zoltán Papp
7d846bf9ba Fix nil pointer exception in expectedSemVer 2025-10-13 20:37:49 +02:00
Zoltán Papp
6200aaf0b0 Fix state handling 2025-10-13 20:21:59 +02:00
Zoltán Papp
7fa926d397 Fix deadlock 2025-10-13 20:14:47 +02:00
Zoltán Papp
9ae48a062a Remove unused codes and remove unnecessary variables 2025-10-13 18:25:29 +02:00
Zoltán Papp
582ff1ff8c Fix auto-update message handling 2025-10-13 18:06:29 +02:00
M. Essam
5556ff36af Merge pull request #4563 from netbirdio/auto-upgrade-mod
Modify client-side behavior
2025-10-12 10:50:45 +03:00
M Essam Hamed
d5ea408cb3 Resolve issues 2025-10-08 19:42:48 +03:00
M Essam Hamed
436d74094b Merge branch 'feat/auto-upgrade' into auto-upgrade-mod 2025-10-08 19:35:20 +03:00
M Essam Hamed
b37ba44015 Resolve issues 2025-10-08 19:33:31 +03:00
M Essam Hamed
0d2ce56e12 Merge branch 'feat/auto-upgrade' into auto-upgrade-mod 2025-10-06 14:58:36 +03:00
M Essam Hamed
723c418966 Merge branch 'main' into feat/auto-upgrade 2025-10-06 14:56:16 +03:00
M Essam Hamed
e04b989a12 Change ProgressBarInfinite to Updating... label 2025-10-01 20:48:47 +03:00
M Essam Hamed
b070304d46 Modify client-side behavior 2025-10-01 17:58:38 +03:00
M. Essam
ad3985ac63 Merge pull request #4504 from netbirdio/sub-feat/auto-upgrade/move-version-networkmap
Move autoUpdateVersion inside NetworkMap
2025-09-21 13:06:33 +03:00
Maycon Santos
50423399f2 Update management/server/account.go
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
2025-09-20 21:56:28 +02:00
M Essam Hamed
02afd4e849 Move to networkMap.PeerConfig 2025-09-16 11:45:54 +03:00
M Essam Hamed
d19f829f65 Move autoUpdateVersion inside NetworkMap 2025-09-15 19:54:10 +03:00
M Essam Hamed
ec47a84afe Remove testing.T.Context() as it's added in go1.24 2025-09-08 12:07:06 +03:00
M. Essam
ecf1e9013e Merge branch 'main' into feat/auto-upgrade 2025-09-07 20:24:56 +03:00
M Essam Hamed
6025eb1962 Add unit tests 2025-09-03 14:10:03 +03:00
M Essam Hamed
59ae92cf8f Refactor handleAutoUpdateVersion to outside handleSync 2025-09-01 15:14:32 +03:00
M Essam Hamed
d2e198bd76 Fix lint 2025-09-01 15:13:14 +03:00
M Essam Hamed
58d48127e0 Define constants for version semantics 2025-09-01 15:13:14 +03:00
M Essam Hamed
84501a3f56 Fix deadlock issues 2025-09-01 15:13:14 +03:00
M Essam Hamed
762b9b7b56 Restructure version.Update to use channel 2025-09-01 15:13:13 +03:00
M Essam Hamed
c6328788ca Resolve comments 2025-09-01 15:13:13 +03:00
M Essam Hamed
bc59749859 Feature: Auto-update client 2025-09-01 15:13:12 +03:00
114 changed files with 10838 additions and 1114 deletions

View File

@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
}
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
if err != nil {
return nil, err
}

View File

@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey
}
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
needsLogin := false
err := WithBackOff(func() error {
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil {
return nil, err
}

View File

@@ -10,6 +10,8 @@ import (
"path/filepath"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/kardianos/service"
"github.com/spf13/cobra"
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
svcConfig.Option["LogDirectory"] = dir
}
}
if err := configureSystemdNetworkd(); err != nil {
log.Warnf("failed to configure systemd-networkd: %v", err)
}
}
if runtime.GOOS == "windows" {
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
return fmt.Errorf("uninstall service: %w", err)
}
if runtime.GOOS == "linux" {
if err := cleanupSystemdNetworkd(); err != nil {
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
}
}
cmd.Println("NetBird service has been uninstalled")
return nil
},
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
return status == service.StatusRunning, nil
}
const (
networkdConf = "/etc/systemd/networkd.conf"
networkdConfDir = "/etc/systemd/networkd.conf.d"
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
# routes and policy rules managed by NetBird.
[Network]
ManageForeignRoutes=no
ManageForeignRoutingPolicyRules=no
`
)
// configureSystemdNetworkd creates a drop-in configuration file to prevent
// systemd-networkd from removing NetBird's routes and policy rules.
func configureSystemdNetworkd() error {
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
log.Debug("systemd-networkd not in use, skipping configuration")
return nil
}
// nolint:gosec // standard networkd permissions
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
return fmt.Errorf("create networkd.conf.d directory: %w", err)
}
// nolint:gosec // standard networkd permissions
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
return fmt.Errorf("write networkd configuration: %w", err)
}
return nil
}
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
func cleanupSystemdNetworkd() error {
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
return nil
}
if err := os.Remove(networkdConfFile); err != nil {
return fmt.Errorf("remove networkd configuration: %w", err)
}
return nil
}

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error
var loginResp *proto.LoginResponse

View File

@@ -15,13 +15,13 @@ import (
)
// NewFirewall creates a firewall manager instance
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
if !iface.IsUserspaceBind() {
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
}
// use userspace packet filtering firewall
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}

View File

@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
// FWType is the type for the firewall type
type FWType int
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
// on the linux system we try to user nftables or iptables
// in any case, because we need to allow netbird interface traffic
// so we use AllowNetbird traffic from these firewall managers
// for the userspace packet filtering firewall
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
if !iface.IsUserspaceBind() {
return fm, err
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
if err != nil {
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
}
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
}
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
fm, err := createFW(iface)
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
fm, err := createFW(iface, mtu)
if err != nil {
return nil, fmt.Errorf("create firewall: %s", err)
}
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
return fm, nil
}
func createFW(iface IFaceMapper) (firewall.Manager, error) {
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
switch check() {
case IPTABLES:
log.Info("creating an iptables firewall manager")
return nbiptables.Create(iface)
return nbiptables.Create(iface, mtu)
case NFTABLES:
log.Info("creating an nftables firewall manager")
return nbnftables.Create(iface)
return nbnftables.Create(iface, mtu)
default:
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
return nil, errors.New("no firewall manager found")
}
}
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
var errUsp error
if fm != nil {
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
} else {
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
}
if errUsp != nil {

View File

@@ -36,7 +36,7 @@ type iFaceMapper interface {
}
// Create iptables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil {
return nil, fmt.Errorf("init iptables: %w", err)
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
ipv4Client: iptablesClient,
}
m.router, err = newRouter(iptablesClient, wgIface)
m.router, err = newRouter(iptablesClient, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}
stateManager.RegisterState(state)

View File

@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
require.NoError(t, err)
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
}
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second)

View File

@@ -30,17 +30,20 @@ const (
chainPOSTROUTING = "POSTROUTING"
chainPREROUTING = "PREROUTING"
chainFORWARD = "FORWARD"
chainRTNAT = "NETBIRD-RT-NAT"
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
chainRTPRE = "NETBIRD-RT-PRE"
chainRTRDR = "NETBIRD-RT-RDR"
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
routingFinalForwardJump = "ACCEPT"
routingFinalNatJump = "MASQUERADE"
jumpManglePre = "jump-mangle-pre"
jumpNatPre = "jump-nat-pre"
jumpNatPost = "jump-nat-post"
jumpMSSClamp = "jump-mss-clamp"
markManglePre = "mark-mangle-pre"
markManglePost = "mark-mangle-post"
matchSet = "--match-set"
@@ -48,6 +51,9 @@ const (
dnatSuffix = "_dnat"
snatSuffix = "_snat"
fwdSuffix = "_fwd"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
type ruleInfo struct {
@@ -77,16 +83,18 @@ type router struct {
ipsetCounter *ipsetCounter
wgIface iFaceMapper
legacyManagement bool
mtu uint16
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
iptablesClient: iptablesClient,
rules: make(map[string][]string),
wgIface: wgIface,
mtu: mtu,
ipFwdState: ipfwdstate.NewIPForwardingState(),
}
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
if err != nil {
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
{chainRTPRE, tableMangle},
{chainRTNAT, tableNat},
{chainRTRDR, tableNat},
{chainRTMSSCLAMP, tableMangle},
} {
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
return fmt.Errorf("add jump rules: %w", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
return nil
}
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
// Add jump rule from FORWARD chain in mangle table to our custom chain
jumpRule := []string{
"-j", chainRTMSSCLAMP,
}
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
}
r.rules[jumpMSSClamp] = jumpRule
ruleOut := []string{
"-o", r.wgIface.Name(),
"-p", "tcp",
"--tcp-flags", "SYN,RST", "SYN",
"-j", "TCPMSS",
"--set-mss", fmt.Sprintf("%d", mss),
}
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
}
r.rules["mss-clamp-out"] = ruleOut
return nil
}
func (r *router) insertEstablishedRule(chain string) error {
establishedRule := getConntrackEstablished()
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
}
func (r *router) cleanJumpRules() error {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
if rule, exists := r.rules[ruleKey]; exists {
var table, chain string
switch ruleKey {
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
case jumpNatPre:
table = tableNat
chain = chainPREROUTING
case jumpMSSClamp:
table = tableMangle
chain = chainFORWARD
default:
return fmt.Errorf("unknown jump rule: %s", ruleKey)
}

View File

@@ -14,6 +14,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
nbnet "github.com/netbirdio/netbird/client/net"
)
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "should return a valid iptables manager")
require.NoError(t, manager.init(nil))
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
assert.NoError(t, manager.Reset(), "shouldn't return error")
}()
// Now 5 rules:
// 1. established rule forward in
// 2. estbalished rule forward out
// 3. jump rule to POST nat chain
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
// 7. static return masquerade rule
// 8. mangle prerouting mark rule
// 9. mangle postrouting mark rule
require.Len(t, manager.rules, 9, "should have created rules map")
// 10. jump rule to MSS clamping chain
// 11. MSS clamping rule for outbound traffic
require.Len(t, manager.rules, 11, "should have created rules map")
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "failed to init iptables client")
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
t.Run(testCase.Name, func(t *testing.T) {
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
manager, err := newRouter(iptablesClient, ifaceMock)
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "shouldn't return error")
require.NoError(t, manager.init(nil))
defer func() {
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
require.NoError(t, err, "Failed to create iptables client")
r, err := newRouter(iptablesClient, ifaceMock)
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router manager")
require.NoError(t, r.init(nil))

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"sync"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -11,6 +12,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
ipt, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
ipt, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create iptables manager: %w", err)
}

View File

@@ -100,6 +100,9 @@ type Manager interface {
//
// If comment argument is empty firewall manager should set
// rule ID as comment for the rule
//
// Note: Callers should call Flush() after adding rules to ensure
// they are applied to the kernel and rule handles are refreshed.
AddPeerFiltering(
id []byte,
ip net.IP,

View File

@@ -29,8 +29,6 @@ const (
chainNameForwardFilter = "netbird-acl-forward-filter"
chainNameManglePrerouting = "netbird-mangle-prerouting"
chainNameManglePostrouting = "netbird-mangle-postrouting"
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
)
const flushError = "flush: %w"
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
// createDefaultAllowRules creates default allow rules for the input and output chains
func (m *AclManager) createDefaultAllowRules() error {
expIn := []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: 12,
Len: 4,
},
// mask
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: []byte{0, 0, 0, 0},
Xor: []byte{0, 0, 0, 0},
},
// net address
&expr.Cmp{
Register: 1,
Data: []byte{0, 0, 0, 0},
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
action firewall.Action,
ipset *nftables.Set,
) (*Rule, error) {
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
if r, ok := m.rules[ruleId]; ok {
return &Rule{
nftRule: r.nftRule,
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
}
if err := m.rConn.Flush(); err != nil {
return nil, fmt.Errorf(flushError, err)
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
}
ruleStruct := &Rule{
nftRule: nftRule,
nftRule: nftRule,
// best effort mangle rule
mangleRule: m.createPreroutingRule(expressions, userData),
nftSet: ipset,
ruleID: ruleId,
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
},
)
return m.rConn.AddRule(&nftables.Rule{
nfRule := m.rConn.AddRule(&nftables.Rule{
Table: m.workTable,
Chain: m.chainPrerouting,
Exprs: preroutingExprs,
UserData: userData,
})
if err := m.rConn.Flush(); err != nil {
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
return nil
}
return nfRule
}
func (m *AclManager) createDefaultChains() (err error) {
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
return nil
}
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":"
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
rulesetID := ":" + string(proto) + ":"
if sPort != nil {
rulesetID += sPort.String()
}

View File

@@ -1,11 +1,11 @@
package nftables
import (
"bytes"
"context"
"fmt"
"net"
"net/netip"
"os"
"sync"
"github.com/google/nftables"
@@ -19,13 +19,22 @@ import (
)
const (
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
tableNameNetbird = "netbird"
// envTableName is the environment variable to override the table name
envTableName = "NB_NFTABLES_TABLE"
tableNameFilter = "filter"
chainNameInput = "INPUT"
)
func getTableName() string {
if name := os.Getenv(envTableName); name != "" {
return name
}
return tableNameNetbird
}
// iFaceMapper defines subset methods of interface required for manager
type iFaceMapper interface {
Name() string
@@ -44,16 +53,16 @@ type Manager struct {
}
// Create nftables firewall manager
func Create(wgIface iFaceMapper) (*Manager, error) {
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
m := &Manager{
rConn: &nftables.Conn{},
wgIface: wgIface,
}
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
var err error
m.router, err = newRouter(workTable, wgIface)
m.router, err = newRouter(workTable, wgIface, mtu)
if err != nil {
return nil, fmt.Errorf("create router: %w", err)
}
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
NameStr: m.wgIface.Name(),
WGAddress: m.wgIface.Address(),
UserspaceBind: m.wgIface.IsUserspaceBind(),
MTU: m.router.mtu,
},
}); err != nil {
log.Errorf("failed to update state: %v", err)
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
m.mutex.Lock()
defer m.mutex.Unlock()
err := m.aclManager.createDefaultAllowRules()
if err != nil {
return fmt.Errorf("failed to create default allow rules: %v", err)
if err := m.aclManager.createDefaultAllowRules(); err != nil {
return fmt.Errorf("create default allow rules: %w", err)
}
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list of chains: %w", err)
}
var chain *nftables.Chain
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
chain = c
break
}
}
if chain == nil {
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
return nil
}
rules, err := m.rConn.GetRules(chain.Table, chain)
if err != nil {
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
}
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
log.Debugf("allow netbird rule already exists: %v", rule)
return nil
}
m.applyAllowNetbirdRules(chain)
err = m.rConn.Flush()
if err != nil {
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
if err := m.rConn.Flush(); err != nil {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
return nil
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
m.mutex.Lock()
defer m.mutex.Unlock()
if err := m.resetNetbirdInputRules(); err != nil {
return fmt.Errorf("reset netbird input rules: %v", err)
}
if err := m.router.Reset(); err != nil {
return fmt.Errorf("reset router: %v", err)
}
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
return nil
}
func (m *Manager) resetNetbirdInputRules() error {
chains, err := m.rConn.ListChains()
if err != nil {
return fmt.Errorf("list chains: %w", err)
}
m.deleteNetbirdInputRules(chains)
return nil
}
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
for _, c := range chains {
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
rules, err := m.rConn.GetRules(c.Table, c)
if err != nil {
log.Errorf("get rules for chain %q: %v", c.Name, err)
continue
}
m.deleteMatchingRules(rules)
}
}
}
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
for _, r := range rules {
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
if err := m.rConn.DelRule(r); err != nil {
log.Errorf("delete rule: %v", err)
}
}
}
}
func (m *Manager) cleanupNetbirdTables() error {
tables, err := m.rConn.ListTables()
if err != nil {
return fmt.Errorf("list tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
@@ -398,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
return nil, fmt.Errorf("list of tables: %w", err)
}
tableName := getTableName()
for _, t := range tables {
if t.Name == tableNameNetbird {
if t.Name == tableName {
m.rConn.DelTable(t)
}
}
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
err = m.rConn.Flush()
return table, err
}
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
rule := &nftables.Rule{
Table: chain.Table,
Chain: chain,
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(m.wgIface.Name()),
},
&expr.Verdict{
Kind: expr.VerdictAccept,
},
},
UserData: []byte(allowNetbirdInputRuleID),
}
_ = m.rConn.InsertRule(rule)
}
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
ifName := ifname(m.wgIface.Name())
for _, rule := range existedRules {
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
if len(rule.Exprs) < 4 {
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
continue
}
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
continue
}
return rule
}
}
}
return nil
}
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
rule := &nftables.Rule{
Table: table,

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/sys/unix"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
func TestNftablesManager(t *testing.T) {
// just check on the local interface
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
func TestNftablesManagerRuleOrder(t *testing.T) {
// This test verifies rule insertion order in nftables peer ACLs
// We add accept rule first, then deny rule to test ordering behavior
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
// just check on the local interface
manager, err := Create(mock)
manager, err := Create(mock, iface.DefaultMTU)
require.NoError(t, err)
require.NoError(t, manager.Init(nil))
time.Sleep(time.Second * 3)
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
stdout, stderr := runIptablesSave(t)
verifyIptablesOutput(t, stdout, stderr)
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "failed to create manager")
require.NoError(t, manager.Init(nil))

View File

@@ -16,6 +16,7 @@ import (
"github.com/google/nftables/xt"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
@@ -32,12 +33,17 @@ const (
chainNameRoutingNat = "netbird-rt-postrouting"
chainNameRoutingRdr = "netbird-rt-redirect"
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
dnatSuffix = "_dnat"
snatSuffix = "_snat"
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const refreshRulesMapError = "refresh rules map: %w"
@@ -63,9 +69,10 @@ type router struct {
wgIface iFaceMapper
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
r := &router{
conn: &nftables.Conn{},
workTable: workTable,
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
rules: make(map[string]*nftables.Rule),
wgIface: wgIface,
ipFwdState: ipfwdstate.NewIPForwardingState(),
mtu: mtu,
}
r.ipsetCounter = refcounter.New(
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
func (r *router) init(workTable *nftables.Table) error {
r.workTable = workTable
if err := r.removeAcceptForwardRules(); err != nil {
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
if err := r.removeAcceptFilterRules(); err != nil {
log.Errorf("failed to clean up rules from filter table: %s", err)
}
if err := r.createContainers(); err != nil {
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
return nil
}
// Reset cleans existing nftables default forward rules from the system
// Reset cleans existing nftables filter table rules from the system
func (r *router) Reset() error {
// clear without deleting the ipsets, the nf table will be deleted by the caller
r.ipsetCounter.Clear()
var merr *multierror.Error
if err := r.removeAcceptForwardRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
if err := r.removeAcceptFilterRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := r.removeNatPreroutingRules(); err != nil {
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
Type: nftables.ChainTypeFilter,
})
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
Name: chainNameMangleForward,
Table: r.workTable,
Hooknum: nftables.ChainHookForward,
Priority: nftables.ChainPriorityMangle,
Type: nftables.ChainTypeFilter,
})
// Add the single NAT rule that matches on mark
if err := r.addPostroutingRules(); err != nil {
return fmt.Errorf("add single nat rule: %v", err)
}
if err := r.addMSSClampingRules(); err != nil {
log.Errorf("failed to add MSS clamping rules: %s", err)
}
if err := r.acceptForwardRules(); err != nil {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
return nil
}
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
// TODO: Add IPv6 support
func (r *router) addMSSClampingRules() error {
mss := r.mtu - ipTCPHeaderMinSize
exprsOut := []expr.Any{
&expr.Meta{
Key: expr.MetaKeyOIFNAME,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: ifname(r.wgIface.Name()),
},
&expr.Meta{
Key: expr.MetaKeyL4PROTO,
Register: 1,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: []byte{unix.IPPROTO_TCP},
},
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 13,
Len: 1,
},
&expr.Bitwise{
DestRegister: 1,
SourceRegister: 1,
Len: 1,
Mask: []byte{0x02},
Xor: []byte{0x00},
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: []byte{0x00},
},
&expr.Counter{},
&expr.Exthdr{
DestRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
&expr.Cmp{
Op: expr.CmpOpGt,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Immediate{
Register: 1,
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
},
&expr.Exthdr{
SourceRegister: 1,
Type: 2,
Offset: 2,
Len: 2,
Op: expr.ExthdrOpTcpopt,
},
}
r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: r.chains[chainNameMangleForward],
Exprs: exprsOut,
})
return nil
}
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
// that our traffic is not dropped by existing rules there.
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
// This method also adds INPUT chain rules to allow traffic to the local interface.
func (r *router) acceptForwardRules() error {
if r.filterTable == nil {
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
fw := "iptables"
defer func() {
log.Debugf("Used %s to add accept forward rules", fw)
log.Debugf("Used %s to add accept forward and input rules", fw)
}()
// Try iptables first and fallback to nftables if iptables is not available
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
fw = "nftables"
return r.acceptForwardRulesNftables()
return r.acceptFilterRulesNftables()
}
return r.acceptForwardRulesIptables(ipt)
return r.acceptFilterRulesIptables(ipt)
}
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
} else {
log.Debugf("added iptables rule: %v", rule)
log.Debugf("added iptables forward rule: %v", rule)
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
} else {
log.Debugf("added iptables input rule: %v", inputRule)
}
return nberrors.FormatErrorOrNil(merr)
}
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
}
}
func (r *router) acceptForwardRulesNftables() error {
func (r *router) getAcceptInputRule() []string {
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
}
func (r *router) acceptFilterRulesNftables() error {
intf := ifname(r.wgIface.Name())
// Rule for incoming interface (iif) with counter
iifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
},
}
// Rule for outgoing interface (oif) with counter
oifRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: "FORWARD",
Name: chainNameForward,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookForward,
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
Exprs: append(oifExprs, getEstablishedExprs(2)...),
UserData: []byte(userDataAcceptForwardRuleOif),
}
r.conn.InsertRule(oifRule)
inputRule := &nftables.Rule{
Table: r.filterTable,
Chain: &nftables.Chain{
Name: chainNameInput,
Table: r.filterTable,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookInput,
Priority: nftables.ChainPriorityFilter,
},
Exprs: []expr.Any{
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: intf,
},
&expr.Counter{},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(userDataAcceptInputRule),
}
r.conn.InsertRule(inputRule)
return nil
}
func (r *router) removeAcceptForwardRules() error {
func (r *router) removeAcceptFilterRules() error {
if r.filterTable == nil {
return nil
}
// Try iptables first and fallback to nftables if iptables is not available
ipt, err := iptables.New()
if err != nil {
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
return r.removeAcceptForwardRulesNftables()
return r.removeAcceptFilterRulesNftables()
}
return r.removeAcceptForwardRulesIptables(ipt)
return r.removeAcceptFilterRulesIptables(ipt)
}
func (r *router) removeAcceptForwardRulesNftables() error {
func (r *router) removeAcceptFilterRulesNftables() error {
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
if err != nil {
return fmt.Errorf("list chains: %v", err)
}
for _, chain := range chains {
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
if chain.Table.Name != r.filterTable.Name {
continue
}
if chain.Name != chainNameForward && chain.Name != chainNameInput {
continue
}
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
for _, rule := range rules {
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("delete rule: %v", err)
}
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
return nil
}
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
var merr *multierror.Error
for _, rule := range r.getAcceptForwardRules() {
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
}
}
inputRule := r.getAcceptInputRule()
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
}
return nberrors.FormatErrorOrNil(merr)
}

View File

@@ -17,6 +17,7 @@ import (
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/test"
"github.com/netbirdio/netbird/client/iface"
)
const (
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
for _, testCase := range test.InsertRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
// need fw manager to init both acl mgr and router for all chains to be present
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
for _, testCase := range test.RemoveRuleTestCases {
t.Run(testCase.Name, func(t *testing.T) {
manager, err := Create(ifaceMock)
manager, err := Create(ifaceMock, iface.DefaultMTU)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
})
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
defer deleteWorkTable()
r, err := newRouter(workTable, ifaceMock)
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
require.NoError(t, err, "Failed to create router")
require.NoError(t, r.init(workTable))

View File

@@ -3,6 +3,7 @@ package nftables
import (
"fmt"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -10,6 +11,7 @@ type InterfaceState struct {
NameStr string `json:"name"`
WGAddress wgaddr.Address `json:"wg_address"`
UserspaceBind bool `json:"userspace_bind"`
MTU uint16 `json:"mtu"`
}
func (i *InterfaceState) Name() string {
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
}
func (s *ShutdownState) Cleanup() error {
nft, err := Create(s.InterfaceState)
mtu := s.InterfaceState.MTU
if mtu == 0 {
mtu = iface.DefaultMTU
}
nft, err := Create(s.InterfaceState, mtu)
if err != nil {
return fmt.Errorf("create nftables manager: %w", err)
}

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"errors"
"fmt"
"net"
@@ -27,7 +28,12 @@ import (
"github.com/netbirdio/netbird/client/internal/statemanager"
)
const layerTypeAll = 0
const (
layerTypeAll = 0
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
ipTCPHeaderMinSize = 40
)
const (
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
@@ -36,6 +42,9 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
@@ -122,6 +131,10 @@ type Manager struct {
netstackServices map[serviceKey]struct{}
netstackServiceMutex sync.RWMutex
mtu uint16
mssClampValue uint16
mssClampEnabled bool
}
// decoder for packages
@@ -140,16 +153,16 @@ type decoder struct {
}
// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger)
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
}
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
if err != nil {
return nil, err
}
@@ -157,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
return mgr, nil
}
func parseCreateEnv() (bool, bool) {
var disableConntrack, enableLocalForwarding bool
func parseCreateEnv() (bool, bool, bool) {
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
var err error
if val := os.Getenv(EnvDisableConntrack); val != "" {
disableConntrack, err = strconv.ParseBool(val)
@@ -177,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
}
}
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
disableMSSClamping, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
}
}
return disableConntrack, enableLocalForwarding
return disableConntrack, enableLocalForwarding, disableMSSClamping
}
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
disableConntrack, enableLocalForwarding := parseCreateEnv()
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
m := &Manager{
decoders: sync.Pool{
@@ -213,13 +232,17 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
dnatMappings: make(map[netip.Addr]netip.Addr),
portDNATRules: []portDNATRule{},
netstackServices: make(map[serviceKey]struct{}),
mtu: mtu,
}
m.routingEnabled.Store(false)
if !disableMSSClamping {
m.mssClampEnabled = true
m.mssClampValue = mtu - ipTCPHeaderMinSize
}
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
return nil, fmt.Errorf("update local IPs: %w", err)
}
if disableConntrack {
log.Info("conntrack is disabled")
} else {
@@ -227,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
}
// netstack needs the forwarder for local traffic
if m.netstack && m.localForwarding {
if err := m.initForwarder(); err != nil {
log.Errorf("failed to initialize forwarder: %v", err)
}
}
if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
@@ -337,7 +357,7 @@ func (m *Manager) initForwarder() error {
return errors.New("forwarding not supported")
}
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
if err != nil {
m.routingEnabled.Store(false)
return fmt.Errorf("create forwarder: %w", err)
@@ -643,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
return false
}
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
switch d.decoded[1] {
case layers.LayerTypeUDP:
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
return true
}
case layers.LayerTypeTCP:
// Clamp MSS on all TCP SYN packets, including those from local IPs.
// SNATed routed traffic may appear as local IP but still requires clamping.
if m.mssClampEnabled {
m.clampTCPMSS(packetData, d)
}
}
m.trackOutbound(d, srcIP, dstIP, packetData, size)
@@ -691,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
return flags
}
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
if !d.tcp.SYN {
return false
}
if len(d.tcp.Options) == 0 {
return false
}
mssOptionIndex := -1
var currentMSS uint16
for i, opt := range d.tcp.Options {
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
if currentMSS > m.mssClampValue {
mssOptionIndex = i
break
}
}
}
if mssOptionIndex == -1 {
return false
}
ipHeaderSize := int(d.ip4.IHL) * 4
if ipHeaderSize < 20 {
return false
}
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
return false
}
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
return true
}
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
tcpHeaderStart := ipHeaderSize
tcpOptionsStart := tcpHeaderStart + 20
optOffset := tcpOptionsStart
for j := 0; j < mssOptionIndex; j++ {
switch d.tcp.Options[j].OptionType {
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
optOffset++
default:
optOffset += 2 + len(d.tcp.Options[j].OptionData)
}
}
mssValueOffset := optOffset + 2
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
return true
}
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
tcpLayer := packetData[tcpHeaderStart:]
tcpLength := len(packetData) - tcpHeaderStart
tcpLayer[16] = 0
tcpLayer[17] = 0
var pseudoSum uint32
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
pseudoSum += uint32(d.ip4.Protocol)
pseudoSum += uint32(tcpLength)
var sum uint32 = pseudoSum
for i := 0; i < tcpLength-1; i += 2 {
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
}
if tcpLength%2 == 1 {
sum += uint32(tcpLayer[tcpLength-1]) << 8
}
for sum > 0xFFFF {
sum = (sum & 0xFFFF) + (sum >> 16)
}
checksum := ^uint16(sum)
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
}
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
transport := d.decoded[1]
switch transport {

View File

@@ -17,6 +17,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
// Create manager and basic setup
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
manager, _ := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
defer b.Cleanup(func() {
require.NoError(b, manager.Close(nil))
})
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
}
}
// generateTCPPacketWithFlags creates a TCP packet with specific flags
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
b.Helper()
ipv4 := &layers.IPv4{
TTL: 64,
Version: 4,
SrcIP: srcIP,
DstIP: dstIP,
Protocol: layers.IPProtocolTCP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
}
// Set TCP flags
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
return buf.Bytes()
}
func BenchmarkRouteACLs(b *testing.B) {
manager := setupRoutedManager(b, "10.10.0.100/16")
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
}
}
}
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
// This shows the overhead difference between the common case (non-SYN packets, fast path)
// and the rare case (SYN packets that need clamping, expensive path).
func BenchmarkMSSClamping(b *testing.B) {
scenarios := []struct {
name string
description string
genPacket func(*testing.B, net.IP, net.IP) []byte
frequency string
}{
{
name: "syn_needs_clamp",
description: "SYN packet needing MSS clamping",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
frequency: "~0.1% of traffic - EXPENSIVE",
},
{
name: "syn_no_clamp_needed",
description: "SYN packet with already-small MSS",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
},
frequency: "~0.05% of traffic",
},
{
name: "tcp_ack",
description: "Non-SYN TCP packet (ACK, data transfer)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
frequency: "~60-70% of traffic - FAST PATH",
},
{
name: "tcp_psh_ack",
description: "TCP data packet (PSH+ACK)",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
},
frequency: "~10-20% of traffic - FAST PATH",
},
{
name: "udp",
description: "UDP packet",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
frequency: "~20-30% of traffic - FAST PATH",
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
// for the common case (non-SYN TCP packets).
func BenchmarkMSSClampingOverhead(b *testing.B) {
scenarios := []struct {
name string
enabled bool
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "disabled_tcp_ack",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "enabled_tcp_ack",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "disabled_syn_needs_clamp",
enabled: false,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "enabled_syn_needs_clamp",
enabled: true,
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = sc.enabled
if sc.enabled {
manager.mssClampValue = 1240
}
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
func BenchmarkMSSClampingMemory(b *testing.B) {
scenarios := []struct {
name string
genPacket func(*testing.B, net.IP, net.IP) []byte
}{
{
name: "tcp_ack_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
},
},
{
name: "syn_needs_clamp",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
},
},
{
name: "udp_fast_path",
genPacket: func(b *testing.B, src, dst net.IP) []byte {
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
},
},
}
for _, sc := range scenarios {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
}()
manager.mssClampEnabled = true
manager.mssClampValue = 1240
srcIP := net.ParseIP("100.64.0.2")
dstIP := net.ParseIP("8.8.8.8")
packet := sc.genPacket(b, srcIP, dstIP)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
manager.filterOutbound(packet, len(packet))
}
})
}
}
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
b.Helper()
ip := &layers.IPv4{
Version: 4,
IHL: 5,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcp := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Seq: 1000,
Window: 65535,
}
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{
FixLengths: true,
ComputeChecksums: true,
}
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
return buf.Bytes()
}

View File

@@ -12,6 +12,7 @@ import (
wgdevice "golang.zx2c4.com/wireguard/device"
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/mocks"
"github.com/netbirdio/netbird/client/iface/wgaddr"
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
require.NotNil(t, manager)
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(tb, err)
require.NoError(tb, manager.EnableRouting())
require.NotNil(tb, manager)
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
},
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))

View File

@@ -1,6 +1,7 @@
package uspfilter
import (
"encoding/binary"
"fmt"
"net"
"net/netip"
@@ -17,6 +18,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nbiface "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Errorf("failed to create Manager: %v", err)
return
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
}
// creating manager instance
manager, err := Create(iface, false, flowLogger)
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
if err != nil {
t.Fatalf("Failed to create Manager: %s", err)
}
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
func TestProcessOutgoingHooks(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close()
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
time.Sleep(time.Second)
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
func TestStatefulFirewall_UDPTracking(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
manager.udpTracker.Close() // Close the existing tracker
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
SetFilterFunc: func(device.PacketFilter) error { return nil },
}
manager, err := Create(ifaceMock, false, flowLogger)
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
require.NoError(t, err)
t.Cleanup(func() {
require.NoError(t, manager.Close(nil))
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
require.Equal(t, tc.expected, isAllowed, tc.desc)
}
}
func TestMSSClamping(t *testing.T) {
ifaceMock := &IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
AddressFunc: func() wgaddr.Address {
return wgaddr.Address{
IP: netip.MustParseAddr("100.10.0.100"),
Network: netip.MustParsePrefix("100.10.0.0/16"),
}
},
}
manager, err := Create(ifaceMock, false, flowLogger, 1280)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
}()
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
err = manager.UpdateLocalIPs()
require.NoError(t, err)
srcIP := net.ParseIP("100.10.0.2")
dstIP := net.ParseIP("8.8.8.8")
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
})
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
lowMSS := uint16(1200)
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
})
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
highMSS := uint16(1460)
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
})
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
manager.filterOutbound(packet, len(packet))
d := parsePacket(t, packet)
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
})
}
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
SYN: true,
ACK: true,
Window: 65535,
Options: []layers.TCPOption{
{
OptionType: layers.TCPOptionKindMSS,
OptionLength: 4,
OptionData: binary.BigEndian.AppendUint16(nil, mss),
},
},
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
tb.Helper()
ipLayer := &layers.IPv4{
Version: 4,
TTL: 64,
Protocol: layers.IPProtocolTCP,
SrcIP: srcIP,
DstIP: dstIP,
}
tcpLayer := &layers.TCP{
SrcPort: layers.TCPPort(srcPort),
DstPort: layers.TCPPort(dstPort),
Window: 65535,
}
if flags&uint16(conntrack.TCPSyn) != 0 {
tcpLayer.SYN = true
}
if flags&uint16(conntrack.TCPAck) != 0 {
tcpLayer.ACK = true
}
if flags&uint16(conntrack.TCPFin) != 0 {
tcpLayer.FIN = true
}
if flags&uint16(conntrack.TCPRst) != 0 {
tcpLayer.RST = true
}
if flags&uint16(conntrack.TCPPush) != 0 {
tcpLayer.PSH = true
}
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
require.NoError(tb, err)
buf := gopacket.NewSerializeBuffer()
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
require.NoError(tb, err)
return buf.Bytes()
}

View File

@@ -45,7 +45,7 @@ type Forwarder struct {
netstack bool
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
HandleLocal: false,
})
mtu, err := iface.GetDevice().MTU()
if err != nil {
return nil, fmt.Errorf("get MTU: %w", err)
}
nicID := tcpip.NICID(1)
endpoint := &endpoint{
logger: logger,
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.CreateNIC(nicID, endpoint); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
return nil, fmt.Errorf("create NIC: %v", err)
}
protoAddr := tcpip.ProtocolAddress{

View File

@@ -49,7 +49,7 @@ type idleConn struct {
conn *udpPacketConn
}
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
ctx, cancel := context.WithCancel(context.Background())
f := &udpForwarder{
logger: logger,

View File

@@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
func BenchmarkDNATConcurrency(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
func BenchmarkDNATMemoryAllocations(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
b.Run(sc.name, func(b *testing.B) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(b, err)
defer func() {
require.NoError(b, manager.Close(nil))

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
)
@@ -16,7 +17,7 @@ import (
func TestDNATTranslationCorrectness(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
func TestDNATMappingManagement(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
func TestInboundPortDNAT(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
func TestInboundPortDNATNegative(t *testing.T) {
manager, err := Create(&IFaceMock{
SetFilterFunc: func(device.PacketFilter) error { return nil },
}, false, flowLogger)
}, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
require.NoError(t, manager.Close(nil))

View File

@@ -10,6 +10,7 @@ import (
fw "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/wgaddr"
)
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
},
}
m, err := Create(ifaceMock, false, flowLogger)
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
require.NoError(t, err)
if !statefulMode {

View File

@@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/acl/mocks"
"github.com/netbirdio/netbird/client/internal/netflow"
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
}).AnyTimes()
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
require.NoError(t, err)
defer func() {
err = fw.Close(nil)

View File

@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
}
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err
}
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID)

View File

@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful
//
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config)
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config)
return authenticateWithDeviceCodeFlow(ctx, config, hint)
}
return pkceFlow, nil
}
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
}
pkceFlowInfo.ProviderConfig.LoginHint = hint
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
}
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil {
switch s, ok := gstatus.FromError(err); {
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
}
}
deviceFlowInfo.ProviderConfig.LoginHint = hint
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
}

View File

@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
}
}
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -280,6 +280,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
return wrapErr(err)
}
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
}
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
state.Set(StatusConnected)
@@ -289,15 +293,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
}
<-engineCtx.Done()
c.engineMutex.Lock()
if c.engine != nil && c.engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
if err := c.engine.Stop(); err != nil {
engine := c.engine
c.engine = nil
c.engineMutex.Unlock()
if engine != nil && engine.wgInterface != nil {
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
if err := engine.Stop(); err != nil {
log.Errorf("Failed to stop engine: %v", err)
}
c.engine = nil
}
c.engineMutex.Unlock()
c.statusRecorder.ClientTeardown()
backOff.Reset()
@@ -382,19 +389,12 @@ func (c *ConnectClient) Status() StatusType {
}
func (c *ConnectClient) Stop() error {
if c == nil {
return nil
engine := c.Engine()
if engine != nil {
if err := engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
}
c.engineMutex.Lock()
defer c.engineMutex.Unlock()
if c.engine == nil {
return nil
}
if err := c.engine.Stop(); err != nil {
return fmt.Errorf("stop engine: %w", err)
}
return nil
}

View File

@@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
config.txt: Anonymized configuration information of the NetBird client.
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
DNS Configuration
The debug bundle includes platform-specific DNS configuration files:
resolv.conf (Unix systems):
- Contains DNS resolver configuration from /etc/resolv.conf
- Includes nameserver entries, search domains, and resolver options
- All IP addresses and domain names are anonymized following the same rules as other files
scutil_dns.txt (macOS only):
- Contains detailed DNS configuration from scutil --dns
- Shows DNS configuration for all network interfaces
- Includes search domains, nameservers, and DNS resolver settings
- All IP addresses and domain names are anonymized
`
const (
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
if err := g.addFirewallRules(); err != nil {
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
}
if err := g.addDNSInfo(); err != nil {
log.Errorf("failed to add DNS info to debug bundle: %v", err)
}
}
func (g *BundleGenerator) addReadme() error {

View File

@@ -0,0 +1,53 @@
//go:build darwin && !ios
package debug
import (
"bytes"
"context"
"fmt"
"os/exec"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
if err := g.addScutilDNS(); err != nil {
log.Errorf("failed to add scutil DNS output: %v", err)
}
return nil
}
func (g *BundleGenerator) addScutilDNS() error {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
cmd := exec.CommandContext(ctx, "scutil", "--dns")
output, err := cmd.CombinedOutput()
if err != nil {
return fmt.Errorf("execute scutil --dns: %w", err)
}
if len(bytes.TrimSpace(output)) == 0 {
return fmt.Errorf("no scutil DNS output")
}
content := string(output)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
return fmt.Errorf("add scutil DNS output to zip: %w", err)
}
return nil
}

View File

@@ -5,3 +5,7 @@ package debug
func (g *BundleGenerator) addRoutes() error {
return nil
}
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,16 @@
//go:build unix && !darwin && !android
package debug
import (
log "github.com/sirupsen/logrus"
)
// addDNSInfo collects and adds DNS configuration information to the archive
func (g *BundleGenerator) addDNSInfo() error {
if err := g.addResolvConf(); err != nil {
log.Errorf("failed to add resolv.conf: %v", err)
}
return nil
}

View File

@@ -0,0 +1,7 @@
//go:build !unix
package debug
func (g *BundleGenerator) addDNSInfo() error {
return nil
}

View File

@@ -0,0 +1,29 @@
//go:build unix && !android
package debug
import (
"fmt"
"os"
"strings"
)
const resolvConfPath = "/etc/resolv.conf"
func (g *BundleGenerator) addResolvConf() error {
data, err := os.ReadFile(resolvConfPath)
if err != nil {
return fmt.Errorf("read %s: %w", resolvConfPath, err)
}
content := string(data)
if g.anonymize {
content = g.anonymizer.AnonymizeString(content)
}
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
return fmt.Errorf("add resolv.conf to zip: %w", err)
}
return nil
}

View File

@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
Scope string
// UseIDToken indicates if the id token should be used for authentication
UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it

View File

@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
// DefaultServer dns server object
type DefaultServer struct {
ctx context.Context
ctxCancel context.CancelFunc
ctx context.Context
ctxCancel context.CancelFunc
shutdownWg sync.WaitGroup
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
// This is different from ServiceEnable=false from management which completely disables the DNS service.
disableSys bool
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
// Stop stops the server
func (s *DefaultServer) Stop() {
s.ctxCancel()
s.shutdownWg.Wait()
s.mux.Lock()
defer s.mux.Unlock()
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.applyHostConfig()
s.shutdownWg.Add(1)
go func() {
// persist dns state right away
defer s.shutdownWg.Done()
if err := s.stateManager.PersistState(s.ctx); err != nil {
log.Errorf("Failed to persist dns state: %v", err)
}

View File

@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
return nil, err
}
pf, err := uspfilter.Create(wgIface, false, flowLogger)
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
if err != nil {
t.Fatalf("failed to create uspfilter: %v", err)
return nil, err

View File

@@ -15,6 +15,7 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
"github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/route"
@@ -134,6 +135,8 @@ func (m *Manager) Stop(ctx context.Context) error {
}
}
m.unregisterNetstackServices()
if err := m.dropDNSFirewall(); err != nil {
mErr = multierror.Append(mErr, err)
}
@@ -158,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
return fmt.Errorf("add udp firewall rule: %w", err)
}
m.fwRules = dnsRules
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
if err != nil {
log.Errorf("failed to add allow DNS router rules, err: %v", err)
return err
return fmt.Errorf("add tcp firewall rule: %w", err)
}
if err := m.firewall.Flush(); err != nil {
return fmt.Errorf("flush: %w", err)
}
m.fwRules = dnsRules
m.tcpRules = tcpRules
m.registerNetstackServices()
return nil
}
func (m *Manager) registerNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) unregisterNetstackServices() {
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
if registrar, ok := m.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
}
}
}
func (m *Manager) dropDNSFirewall() error {
var mErr *multierror.Error
for _, rule := range m.fwRules {

View File

@@ -50,6 +50,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/internal/updatemanager"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -75,6 +76,7 @@ const (
PeerConnectionTimeoutMax = 45000 // ms
PeerConnectionTimeoutMin = 30000 // ms
connInitLimit = 200
disableAutoUpdate = "disabled"
)
var ErrResetConnection = fmt.Errorf("reset connection")
@@ -148,6 +150,8 @@ type Engine struct {
// syncMsgMux is used to guarantee sequential Management Service message processing
syncMsgMux *sync.Mutex
// sshMux protects sshServer field access
sshMux sync.Mutex
config *EngineConfig
mobileDep MobileDependency
@@ -199,9 +203,14 @@ type Engine struct {
connSemaphore *semaphoregroup.SemaphoreGroup
flowManager nftypes.FlowManager
// auto-update
updateManager *updatemanager.UpdateManager
// WireGuard interface monitor
wgIfaceMonitor *WGIfaceMonitor
wgIfaceMonitorWg sync.WaitGroup
wgIfaceMonitor *WGIfaceMonitor
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
shutdownWg sync.WaitGroup
probeStunTurn *relay.StunTurnProbe
}
@@ -298,21 +307,20 @@ func (e *Engine) Stop() error {
e.ingressGatewayMgr = nil
}
e.stopDNSForwarder()
if e.routeManager != nil {
e.routeManager.Stop(e.stateManager)
}
if e.dnsForwardMgr != nil {
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.dnsForwardMgr = nil
}
if e.srWatcher != nil {
e.srWatcher.Close()
}
if e.updateManager != nil {
e.updateManager.Stop()
}
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
e.statusRecorder.UpdateRelayStates([]relay.ProbeResult{})
@@ -325,10 +333,6 @@ func (e *Engine) Stop() error {
e.cancel()
}
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
// Removing peers happens in the conn.Close() asynchronously
time.Sleep(500 * time.Millisecond)
e.close()
// stop flow manager after wg interface is gone
@@ -336,8 +340,6 @@ func (e *Engine) Stop() error {
e.flowManager.Close()
}
log.Infof("stopped Netbird Engine")
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -348,12 +350,52 @@ func (e *Engine) Stop() error {
log.Errorf("failed to persist state: %v", err)
}
// Stop WireGuard interface monitor and wait for it to exit
e.wgIfaceMonitorWg.Wait()
timeout := e.calculateShutdownTimeout()
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
}
log.Infof("stopped Netbird Engine")
return nil
}
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
func (e *Engine) calculateShutdownTimeout() time.Duration {
peerCount := len(e.peerStore.PeersPubKey())
baseTimeout := 10 * time.Second
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
timeout := baseTimeout + perPeerTimeout
maxTimeout := 30 * time.Second
if timeout > maxTimeout {
timeout = maxTimeout
}
return timeout
}
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
}
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
// Connections to remote peers are not established here.
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
@@ -483,14 +525,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
// monitor WireGuard interface lifecycle and restart engine on changes
e.wgIfaceMonitor = NewWGIfaceMonitor()
e.wgIfaceMonitorWg.Add(1)
e.shutdownWg.Add(1)
go func() {
defer e.wgIfaceMonitorWg.Done()
defer e.shutdownWg.Done()
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
e.restartEngine()
e.triggerClientRestart()
} else if err != nil {
log.Warnf("WireGuard interface monitor: %s", err)
}
@@ -499,6 +541,19 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
return nil
}
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.updateManager == nil {
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
}
e.updateManager.CheckUpdateSuccess(e.ctx)
e.handleAutoUpdateVersion(autoUpdateSettings, true)
}
func (e *Engine) createFirewall() error {
if e.config.DisableFirewall {
log.Infof("firewall is disabled")
@@ -506,7 +561,7 @@ func (e *Engine) createFirewall() error {
}
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil || e.firewall == nil {
log.Errorf("failed creating firewall manager: %s", err)
return nil
@@ -674,9 +729,11 @@ func (e *Engine) removeAllPeers() error {
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
e.sshMux.Lock()
if !isNil(e.sshServer) {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.sshMux.Unlock()
e.connMgr.RemovePeerConn(peerKey)
@@ -711,10 +768,44 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
return nil
}
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
if autoUpdateSettings == nil {
return
}
disabled := autoUpdateSettings.Version == disableAutoUpdate
// Stop and cleanup if disabled
if e.updateManager != nil && disabled {
log.Infof("auto-update is disabled, stopping update manager")
e.updateManager.Stop()
e.updateManager = nil
return
}
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
return
}
// Start manager if needed
if e.updateManager == nil {
log.Infof("starting auto-update manager")
e.updateManager = updatemanager.NewUpdateManager(e.statusRecorder, e.stateManager)
}
e.updateManager.Start(e.ctx)
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
e.updateManager.SetVersion(autoUpdateSettings.Version)
}
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
}
if update.GetNetbirdConfig() != nil {
wCfg := update.GetNetbirdConfig()
err := e.updateTURNs(wCfg.GetTurns())
@@ -878,6 +969,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
return nil
}
e.sshMux.Lock()
// start SSH server if it wasn't running
if isNil(e.sshServer) {
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
@@ -885,34 +977,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
}
// nil sshServer means it has not yet been started
var err error
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
if err != nil {
e.sshMux.Unlock()
return fmt.Errorf("create ssh server: %w", err)
}
e.sshServer = server
e.sshMux.Unlock()
go func() {
// blocking
err = e.sshServer.Start()
err = server.Start()
if err != nil {
// will throw error when we stop it even if it is a graceful stop
log.Debugf("stopped SSH server with error %v", err)
}
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
e.sshMux.Lock()
e.sshServer = nil
e.sshMux.Unlock()
log.Infof("stopped SSH server")
}()
} else {
e.sshMux.Unlock()
log.Debugf("SSH server is already running")
}
} else if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
} else {
e.sshMux.Lock()
if !isNil(e.sshServer) {
// Disable SSH server request, so stop it if it was running
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed to stop SSH server %v", err)
}
e.sshServer = nil
}
e.sshServer = nil
e.sshMux.Unlock()
}
return nil
}
@@ -949,7 +1049,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
// E.g. when a new peer has been registered and we are allowed to connect to it.
func (e *Engine) receiveManagementEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
if err != nil {
log.Warnf("failed to get system info with checks: %v", err)
@@ -1125,6 +1227,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
e.sshMux.Lock()
if !isNil(e.sshServer) {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
@@ -1135,6 +1238,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
}
}
}
e.sshMux.Unlock()
}
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
@@ -1377,7 +1481,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
func (e *Engine) receiveSignalEvents() {
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
// connect to a stream of messages coming from the signal server
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
e.syncMsgMux.Lock()
@@ -1494,12 +1600,14 @@ func (e *Engine) close() {
e.statusRecorder.SetWgIface(nil)
}
e.sshMux.Lock()
if !isNil(e.sshServer) {
err := e.sshServer.Stop()
if err != nil {
log.Warnf("failed stopping the SSH server: %v", err)
}
}
e.sshMux.Unlock()
if e.firewall != nil {
err := e.firewall.Close(e.stateManager)
@@ -1730,8 +1838,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
return allHealthy
}
// restartEngine restarts the engine by cancelling the client context
func (e *Engine) restartEngine() {
// triggerClientRestart triggers a full client restart by cancelling the client context.
// Note: This does NOT just restart the engine - it cancels the entire client context,
// which causes the connect client's retry loop to create a completely new engine.
func (e *Engine) triggerClientRestart() {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
@@ -1753,7 +1863,9 @@ func (e *Engine) startNetworkMonitor() {
}
e.networkMonitor = networkmonitor.New()
e.shutdownWg.Add(1)
go func() {
defer e.shutdownWg.Done()
if err := e.networkMonitor.Listen(e.ctx); err != nil {
if errors.Is(err, context.Canceled) {
log.Infof("network monitor stopped")
@@ -1763,8 +1875,8 @@ func (e *Engine) startNetworkMonitor() {
return
}
log.Infof("Network monitor: detected network change, restarting engine")
e.restartEngine()
log.Infof("Network monitor: detected network change, triggering client restart")
e.triggerClientRestart()
}()
}
@@ -1873,7 +1985,6 @@ func (e *Engine) updateDNSForwarder(
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
e.registerDNSServices()
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
log.Errorf("failed to start DNS forward: %v", err)
@@ -1893,34 +2004,9 @@ func (e *Engine) stopDNSForwarder() {
log.Errorf("failed to stop DNS forward: %v", err)
}
e.unregisterDNSServices()
e.dnsForwardMgr = nil
}
func (e *Engine) registerDNSServices() {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
registrar.RegisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
}
}
}
func (e *Engine) unregisterDNSServices() {
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.UDP, nbdns.ForwarderServerPort)
registrar.UnregisterNetstackService(nftypes.TCP, nbdns.ForwarderServerPort)
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", nbdns.ForwarderServerPort)
}
}
}
func (e *Engine) GetNet() (*netstack.Net, error) {
e.syncMsgMux.Lock()
intf := e.wgInterface

View File

@@ -24,6 +24,7 @@ import (
// Manager handles netflow tracking and logging
type Manager struct {
mux sync.Mutex
shutdownWg sync.WaitGroup
logger nftypes.FlowLogger
flowConfig *nftypes.FlowConfig
conntrack nftypes.ConnTracker
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
ctx, cancel := context.WithCancel(context.Background())
m.cancel = cancel
go m.receiveACKs(ctx, flowClient)
go m.startSender(ctx)
m.shutdownWg.Add(2)
go func() {
defer m.shutdownWg.Done()
m.receiveACKs(ctx, flowClient)
}()
go func() {
defer m.shutdownWg.Done()
m.startSender(ctx)
}()
return nil
}
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
// Close cleans up all resources
func (m *Manager) Close() {
m.mux.Lock()
defer m.mux.Unlock()
if err := m.disableFlow(); err != nil {
log.Warnf("failed to disable flow manager: %v", err)
}
m.mux.Unlock()
m.shutdownWg.Wait()
}
// GetLogger returns the flow logger

View File

@@ -19,11 +19,10 @@ type SRWatcher struct {
signalClient chNotifier
relayManager chNotifier
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
listeners map[chan struct{}]struct{}
mu sync.Mutex
iFaceDiscover stdnet.ExternalIFaceDiscover
iceConfig ice.Config
cancelIceMonitor context.CancelFunc
}

View File

@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
if isController(w.config) {
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
} else {
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
}

View File

@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
}
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it

View File

@@ -81,6 +81,7 @@ type DefaultManager struct {
ctx context.Context
stop context.CancelFunc
mux sync.Mutex
shutdownWg sync.WaitGroup
clientNetworks map[route.HAUniqueID]*client.Watcher
routeSelector *routeselector.RouteSelector
serverRouter *server.Router
@@ -283,6 +284,7 @@ func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
// Stop stops the manager watchers and clean firewall rules
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
m.stop()
m.shutdownWg.Wait()
if m.serverRouter != nil {
m.serverRouter.CleanUp()
}
@@ -485,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
}
clientNetworkWatcher := client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
}
@@ -527,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
}
clientNetworkWatcher = client.NewWatcher(config)
m.clientNetworks[id] = clientNetworkWatcher
go clientNetworkWatcher.Start()
m.shutdownWg.Add(1)
go func() {
defer m.shutdownWg.Done()
clientNetworkWatcher.Start()
}()
}
update := client.RoutesUpdate{
UpdateSerial: updateSerial,

View File

@@ -9,8 +9,6 @@ import (
"github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route"
)
@@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
defer rs.mu.RUnlock()
if rs.deselectAll {
log.Debugf("Route %s not selected (deselect all)", routeID)
return false
}
_, deselected := rs.deselectedRoutes[routeID]
isSelected := !deselected
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
return isSelected
}

View File

@@ -0,0 +1,386 @@
package updatemanager
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
v "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/version"
)
const (
latestVersion = "latest"
// this version will be ignored
developmentVersion = "development"
)
type UpdateInterface interface {
StopWatch()
SetDaemonVersion(newVersion string) bool
SetOnUpdateListener(updateFn func())
LatestVersion() *v.Version
StartFetcher()
}
type UpdateState struct {
PreUpdateVersion string
TargetVersion string
}
func (u UpdateState) Name() string {
return "autoUpdate"
}
type UpdateManager struct {
statusRecorder *peer.Status
stateManager *statemanager.Manager
lastTrigger time.Time
mgmUpdateChan chan struct{}
updateChannel chan struct{}
currentVersion string
update UpdateInterface
wg sync.WaitGroup
cancel context.CancelFunc
expectedVersion *v.Version
updateToLatestVersion bool
// updateMutex protect update and expectedVersion fields
updateMutex sync.Mutex
}
func NewUpdateManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *UpdateManager {
manager := &UpdateManager{
statusRecorder: statusRecorder,
stateManager: stateManager,
mgmUpdateChan: make(chan struct{}, 1),
updateChannel: make(chan struct{}, 1),
currentVersion: version.NetbirdVersion(),
update: version.NewUpdate("nb/client"),
}
return manager
}
// CheckUpdateSuccess checks if the update was successful. It works without to start the update manager.
func (u *UpdateManager) CheckUpdateSuccess(ctx context.Context) {
u.updateStateManager(ctx)
return
}
func (u *UpdateManager) Start(ctx context.Context) {
if u.cancel != nil {
log.Errorf("UpdateManager already started")
return
}
u.update.SetDaemonVersion(u.currentVersion)
u.update.SetOnUpdateListener(func() {
select {
case u.updateChannel <- struct{}{}:
default:
}
})
go u.update.StartFetcher()
ctx, cancel := context.WithCancel(ctx)
u.cancel = cancel
u.wg.Add(1)
go u.updateLoop(ctx)
}
func (u *UpdateManager) SetVersion(expectedVersion string) {
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
if u.cancel == nil {
log.Errorf("UpdateManager not started")
return
}
u.updateMutex.Lock()
defer u.updateMutex.Unlock()
if expectedVersion == latestVersion {
u.updateToLatestVersion = true
u.expectedVersion = nil
} else {
expectedSemVer, err := v.NewVersion(expectedVersion)
if err != nil {
log.Errorf("Error parsing version: %v", err)
return
}
if u.expectedVersion != nil && u.expectedVersion.Equal(expectedSemVer) {
return
}
u.expectedVersion = expectedSemVer
u.updateToLatestVersion = false
}
select {
case u.mgmUpdateChan <- struct{}{}:
default:
}
}
func (u *UpdateManager) Stop() {
if u.cancel == nil {
return
}
u.cancel()
u.updateMutex.Lock()
if u.update != nil {
u.update.StopWatch()
u.update = nil
}
u.updateMutex.Unlock()
u.wg.Wait()
}
func (u *UpdateManager) onContextCancel() {
if u.cancel == nil {
return
}
u.updateMutex.Lock()
defer u.updateMutex.Unlock()
if u.update != nil {
u.update.StopWatch()
u.update = nil
}
}
func (u *UpdateManager) updateLoop(ctx context.Context) {
defer u.wg.Done()
for {
select {
case <-ctx.Done():
u.onContextCancel()
return
case <-u.mgmUpdateChan:
case <-u.updateChannel:
log.Infof("fetched new version info")
}
u.handleUpdate(ctx)
}
}
func (u *UpdateManager) handleUpdate(ctx context.Context) {
var updateVersion *v.Version
u.updateMutex.Lock()
if u.update == nil {
u.updateMutex.Unlock()
return
}
expectedVersion := u.expectedVersion
useLatest := u.updateToLatestVersion
curLatestVersion := u.update.LatestVersion()
u.updateMutex.Unlock()
switch {
// Resolve "latest" to actual version
case useLatest:
if curLatestVersion == nil {
log.Tracef("latest version not fetched yet")
return
}
updateVersion = curLatestVersion
// Update to specific version
case expectedVersion != nil:
updateVersion = expectedVersion
default:
log.Debugf("no expected version information set")
return
}
log.Debugf("checking update option, current version: %s, target version: %s", u.currentVersion, updateVersion)
if !u.shouldUpdate(updateVersion) {
return
}
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
u.lastTrigger = time.Now()
log.Debugf("Auto-update triggered, current version: %s, target version: %s", u.currentVersion, updateVersion)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Automatically updating client",
"Your client version is older than auto-update version set in Management, updating client now.",
nil,
)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "show"},
)
updateState := UpdateState{
PreUpdateVersion: u.currentVersion,
TargetVersion: updateVersion.String(),
}
if err := u.stateManager.UpdateState(updateState); err != nil {
log.Warnf("failed to update state: %v", err)
} else {
if err = u.stateManager.PersistState(ctx); err != nil {
log.Warnf("failed to persist state: %v", err)
}
}
if err := u.triggerUpdate(ctx, updateVersion.String()); err != nil {
log.Errorf("Error triggering auto-update: %v", err)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_ERROR,
cProto.SystemEvent_SYSTEM,
"Auto-update failed",
fmt.Sprintf("Auto-update failed: %v", err),
nil,
)
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"",
"",
map[string]string{"progress_window": "hide"},
)
}
}
func (u *UpdateManager) updateStateManager(ctx context.Context) {
stateType := &UpdateState{}
u.stateManager.RegisterState(stateType)
if err := u.stateManager.LoadState(stateType); err != nil {
log.Errorf("failed to load state: %v", err)
return
}
state := u.stateManager.GetState(stateType)
if state == nil {
return
}
updateState, ok := state.(*UpdateState)
if !ok {
log.Errorf("failed to cast state to UpdateState")
return
}
log.Debugf("autoUpdate state loaded, %v", *updateState)
if updateState.TargetVersion == u.currentVersion {
log.Infof("published notification event")
u.statusRecorder.PublishEvent(
cProto.SystemEvent_INFO,
cProto.SystemEvent_SYSTEM,
"Auto-update completed",
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", u.currentVersion),
nil,
)
}
if err := u.stateManager.DeleteState(updateState); err != nil {
log.Errorf("failed to delete state: %v", err)
} else if err = u.stateManager.PersistState(ctx); err != nil {
log.Errorf("failed to persist state: %v", err)
}
}
func (u *UpdateManager) shouldUpdate(updateVersion *v.Version) bool {
if u.currentVersion == developmentVersion {
log.Debugf("skipping auto-update, running development version")
return false
}
currentVersion, err := v.NewVersion(u.currentVersion)
if err != nil {
log.Errorf("error checking for update, error parsing version `%s`: %v", u.currentVersion, err)
return false
}
if currentVersion.GreaterThanOrEqual(updateVersion) {
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", u.currentVersion, updateVersion)
return false
}
if time.Since(u.lastTrigger) < 5*time.Minute {
log.Debugf("skipping auto-update, last update was %s ago", time.Since(u.lastTrigger))
return false
}
return true
}
func downloadFileToTemporaryDir(ctx context.Context, fileURL string) (string, error) { //nolint:unused
tempDir, err := os.MkdirTemp("", "netbird-installer-*")
if err != nil {
return "", fmt.Errorf("error creating temporary directory: %w", err)
}
// Clean up temp directory on error
var success bool
defer func() {
if !success {
if err := os.RemoveAll(tempDir); err != nil {
log.Errorf("error cleaning up temporary directory: %v", err)
}
}
}()
fileNameParts := strings.Split(fileURL, "/")
out, err := os.Create(filepath.Join(tempDir, fileNameParts[len(fileNameParts)-1]))
if err != nil {
return "", fmt.Errorf("error creating temporary file: %w", err)
}
defer func() {
if err := out.Close(); err != nil {
log.Errorf("error closing temporary file: %v", err)
}
}()
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
if err != nil {
return "", fmt.Errorf("error creating file download request: %w", err)
}
resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("error downloading file: %w", err)
}
defer func() {
if err := resp.Body.Close(); err != nil {
log.Errorf("Error closing response body: %v", err)
}
}()
if resp.StatusCode != http.StatusOK {
log.Errorf("error downloading update file, received status code: %d", resp.StatusCode)
return "", fmt.Errorf("error downloading file, received status code: %d", resp.StatusCode)
}
_, err = io.Copy(out, resp.Body)
if err != nil {
return "", fmt.Errorf("error downloading file: %w", err)
}
log.Infof("downloaded update file to %s", out.Name())
success = true // Mark success to prevent cleanup
return out.Name(), nil
}

View File

@@ -0,0 +1,213 @@
package updatemanager
import (
"context"
"fmt"
v "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/statemanager"
"path"
"testing"
"time"
)
func (u *UpdateManager) WithCustomVersionUpdate(versionUpdate UpdateInterface) *UpdateManager {
u.update = versionUpdate
return u
}
type versionUpdateMock struct {
latestVersion *v.Version
onUpdate func()
}
func (v versionUpdateMock) StopWatch() {}
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
return false
}
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
v.onUpdate = updateFn
}
func (v versionUpdateMock) LatestVersion() *v.Version {
return v.latestVersion
}
func (v versionUpdateMock) StartFetcher() {}
func Test_LatestVersion(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
initialLatestVersion *v.Version
latestVersion *v.Version
shouldUpdateInit bool
shouldUpdateLater bool
}{
{
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
daemonVersion: "1.0.0",
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
latestVersion: v.Must(v.NewSemver("1.0.2")),
shouldUpdateInit: true,
shouldUpdateLater: false,
},
{
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
daemonVersion: "1.0.0",
initialLatestVersion: nil,
latestVersion: v.Must(v.NewSemver("1.0.1")),
shouldUpdateInit: false,
shouldUpdateLater: true,
},
}
for idx, c := range testMatrix {
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(mockUpdate)
targetVersionChan := make(chan string, 1)
m.updateFunc = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion("latest")
var triggeredInit bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.initialLatestVersion.String() {
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
}
triggeredInit = true
case <-time.After(10 * time.Millisecond):
triggeredInit = false
}
if triggeredInit != c.shouldUpdateInit {
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
}
mockUpdate.latestVersion = c.latestVersion
mockUpdate.onUpdate()
var triggeredLater bool
select {
case targetVersion := <-targetVersionChan:
if targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
}
triggeredLater = true
case <-time.After(10 * time.Millisecond):
triggeredLater = false
}
if triggeredLater != c.shouldUpdateLater {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
}
m.Stop()
}
}
func Test_HandleUpdate(t *testing.T) {
testMatrix := []struct {
name string
daemonVersion string
latestVersion *v.Version
expectedVersion string
shouldUpdate bool
}{
{
name: "Update to a specific version should update regardless of if latestVersion is available yet",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.56.0",
shouldUpdate: true,
},
{
name: "Update to specific version should not update if version matches",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.55.0",
shouldUpdate: false,
},
{
name: "Update to specific version should not update if current version is newer",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "0.54.0",
shouldUpdate: false,
},
{
name: "Update to latest version should update if latest is newer",
daemonVersion: "0.55.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: true,
},
{
name: "Update to latest version should not update if latest == current",
daemonVersion: "0.56.0",
latestVersion: v.Must(v.NewSemver("0.56.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if daemon version is invalid",
daemonVersion: "development",
latestVersion: v.Must(v.NewSemver("1.0.0")),
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expecting latest and latest version is unavailable",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "latest",
shouldUpdate: false,
},
{
name: "Should not update if expected version is invalid",
daemonVersion: "0.55.0",
latestVersion: nil,
expectedVersion: "development",
shouldUpdate: false,
},
}
for idx, c := range testMatrix {
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
m := NewUpdateManager(peer.NewRecorder(""), statemanager.New(tmpFile)).WithCustomVersionUpdate(&versionUpdateMock{latestVersion: c.latestVersion})
targetVersionChan := make(chan string, 1)
m.updateFunc = func(ctx context.Context, targetVersion string) error {
targetVersionChan <- targetVersion
return nil
}
m.currentVersion = c.daemonVersion
m.Start(context.Background())
m.SetVersion(c.expectedVersion)
var updateTriggered bool
select {
case targetVersion := <-targetVersionChan:
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
}
updateTriggered = true
case <-time.After(10 * time.Millisecond):
updateTriggered = false
}
if updateTriggered != c.shouldUpdate {
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
}
m.Stop()
}
}

View File

@@ -0,0 +1,118 @@
//go:build darwin
package updatemanager
import (
"context"
"fmt"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"syscall"
)
const (
pkgDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
outBytes, err := cmd.Output()
if err != nil && cmd.ProcessState.ExitCode() == 1 {
// Not installed using pkg file, thus installed using Homebrew
return updateHomeBrew(ctx)
}
// Installed using pkg file
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(targetVersion))
if err != nil {
return fmt.Errorf("error downloading update file: %w", err)
}
volume := "/"
for _, v := range strings.Split(string(outBytes), "\n") {
trimmed := strings.TrimSpace(v)
if strings.HasPrefix(trimmed, "volume: ") {
volume = strings.Split(trimmed, ": ")[1]
}
}
cmd = exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
err = cmd.Start()
if err != nil {
return fmt.Errorf("error running pkg file: %w", err)
}
err = cmd.Process.Release()
return err
}
func updateHomeBrew(ctx context.Context) error {
// Homebrew must be run as a non-root user
// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
fileInfo, err := os.Stat("/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/")
if err != nil {
return fmt.Errorf("error getting homebrew installation path info: %w", err)
}
fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
if !ok {
return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
}
// Get username from UID
installer, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
if err != nil {
return fmt.Errorf("error looking up brew installer user: %w", err)
}
userName := installer.Name
// Get user HOME, required for brew to run correctly
// https://github.com/Homebrew/brew/issues/15833
homeDir := installer.HomeDir
// Homebrew does not support installing specific versions
// Thus it will always update to latest and ignore targetVersion
upgradeArgs := []string{"-u", userName, "/opt/homebrew/bin/brew", "upgrade", "netbirdio/tap/netbird"}
// Check if netbird-ui is installed
cmd := exec.CommandContext(ctx, "brew", "info", "--json", "netbirdio/tap/netbird-ui")
err = cmd.Run()
if err == nil {
// netbird-ui is installed
upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
}
cmd = exec.CommandContext(ctx, "sudo", upgradeArgs...)
cmd.Env = append(cmd.Env, "HOME="+homeDir)
// Homebrew upgrade doesn't restart the client on its own
// So we have to wait for it to finish running and ensure it's done
// And then basically restart the netbird service
err = cmd.Run()
if err != nil {
return fmt.Errorf("error running brew upgrade: %w", err)
}
currentPID := os.Getpid()
// Restart netbird service after the fact
// This is a workaround since attempting to restart using launchctl will kill the service and die before starting
// the service again as it's a child process
// using SIGTERM should ensure a clean shutdown
process, err := os.FindProcess(currentPID)
if err != nil {
return fmt.Errorf("error finding current process: %w", err)
}
err = process.Signal(syscall.SIGTERM)
if err != nil {
return fmt.Errorf("error sending SIGTERM to current process: %w", err)
}
// We're dying now, which should restart us
return nil
}
func urlWithVersionArch(version string) string {
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -0,0 +1,10 @@
//go:build freebsd
package updatemanager
import "context"
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
// TODO: Implement
return nil
}

View File

@@ -0,0 +1,10 @@
//go:build js
package updatemanager
import "context"
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
// TODO: Implement
return nil
}

View File

@@ -0,0 +1,10 @@
//go:build linux
package updatemanager
import "context"
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
// TODO: Implement
return nil
}

View File

@@ -0,0 +1,96 @@
//go:build windows
package updatemanager
import (
"context"
"os/exec"
"runtime"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/windows/registry"
)
const (
msiDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
exeDownloadURL = "https://github.com/netbirdio/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
installerEXE installerType = "EXE"
installerMSI installerType = "MSI"
)
type installerType string
func (u *UpdateManager) triggerUpdate(ctx context.Context, targetVersion string) error {
method := installation()
return install(ctx, method, targetVersion)
}
func installation() installerType {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath64, registry.QUERY_VALUE)
if err != nil {
k, err = registry.OpenKey(registry.LOCAL_MACHINE, uninstallKeyPath32, registry.QUERY_VALUE)
if err != nil {
return installerMSI
} else {
err = k.Close()
if err != nil {
log.Warnf("Error closing registry key: %v", err)
}
}
} else {
err = k.Close()
if err != nil {
log.Warnf("Error closing registry key: %v", err)
}
}
return installerEXE
}
func install(ctx context.Context, installerType installerType, targetVersion string) error {
path, err := downloadFileToTemporaryDir(ctx, urlWithVersionArch(installerType, targetVersion))
if err != nil {
return err
}
log.Infof("start installation %s", path)
var cmd *exec.Cmd
if installerType == installerEXE {
cmd = exec.CommandContext(ctx, path, "/S")
} else {
cmd = exec.CommandContext(ctx, "msiexec", "/quiet", "/i", path)
}
// Detach the process from the parent
cmd.SysProcAttr = &syscall.SysProcAttr{
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
}
if err := cmd.Start(); err != nil {
log.Errorf("error starting installer: %v", err)
return err
}
if err := cmd.Process.Release(); err != nil {
log.Errorf("error releasing installer process: %v", err)
return err
}
log.Infof("installer started successfully: %s", path)
return nil
}
func urlWithVersionArch(it installerType, version string) string {
var url string
if it == installerEXE {
url = exeDownloadURL
} else {
url = msiDownloadURL
}
url = strings.ReplaceAll(url, "%version", version)
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
}

View File

@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile,
})
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
if err != nil {
return err.Error()
}

View File

@@ -279,8 +279,10 @@ type LoginRequest struct {
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
// hint is used to pre-fill the email/username field during SSO authentication
Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *LoginRequest) Reset() {
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
return 0
}
func (x *LoginRequest) GetHint() string {
if x != nil && x.Hint != nil {
return *x.Hint
}
return ""
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" +
"\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\xc3\x0e\n" +
"\fEmptyRequest\"\xe5\x0e\n" +
"\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" +
"\t_usernameB\x06\n" +
"\x04_mtu\"\xb5\x01\n" +
"\x04_mtuB\a\n" +
"\x05_hint\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +

View File

@@ -158,6 +158,9 @@ message LoginRequest {
optional string username = 31;
optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
}
message LoginResponse {

View File

@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting)
if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
if err != nil {
state.Set(internal.StatusLoginFailed)
return nil, err

View File

@@ -93,13 +93,14 @@ func main() {
showLoginURL: flags.showLoginURL,
showDebug: flags.showDebug,
showProfiles: flags.showProfiles,
showUpdate: flags.showUpdate,
})
// Watch for theme/settings changes to update the icon.
go watchSettingsChanges(a, client)
// Run in window mode if any UI flag was set.
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles {
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showUpdate {
a.Run()
return
}
@@ -127,6 +128,7 @@ type cliFlags struct {
showDebug bool
showLoginURL bool
errorMsg string
showUpdate bool
saveLogsInFile bool
}
@@ -146,6 +148,7 @@ func parseFlags() *cliFlags {
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
flag.Parse()
return &flags
}
@@ -296,6 +299,8 @@ type serviceClient struct {
mExitNodeDeselectAll *systray.MenuItem
logFile string
wLoginURL fyne.Window
wUpdateProgress fyne.Window
updateContextCancel context.CancelFunc
connectCancel context.CancelFunc
}
@@ -314,6 +319,7 @@ type newServiceClientArgs struct {
showDebug bool
showLoginURL bool
showProfiles bool
showUpdate bool
}
// newServiceClient instance constructor
@@ -331,7 +337,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
showAdvancedSettings: args.showSettings,
showNetworks: args.showNetworks,
update: version.NewUpdate("nb/client-ui"),
update: version.NewUpdateAndStart("nb/client-ui"),
}
s.eventHandler = newEventHandler(s)
@@ -349,6 +355,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
s.showDebugUI()
case args.showProfiles:
s.showProfilesUI()
case args.showUpdate:
s.showUpdateProgress(ctx)
}
return s
@@ -394,6 +402,30 @@ func (s *serviceClient) updateIcon() {
s.updateIndicationLock.Unlock()
}
func (s *serviceClient) showUpdateProgress(ctx context.Context) {
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
loadingLabel := widget.NewLabel("Updating")
s.wUpdateProgress.SetContent(container.NewGridWithRows(2, widget.NewLabel("Your client version is older than auto-update version set in Management, updating client now."), loadingLabel))
s.wUpdateProgress.Show()
go func() {
dotCount := 0
for {
select {
case <-ctx.Done():
return
case <-time.After(time.Second):
dotCount++
dotCount %= 4
loadingLabel.SetText(fmt.Sprintf("Updating%s", strings.Repeat(".", dotCount)))
}
}
}()
s.wUpdateProgress.CenterOnScreen()
s.wUpdateProgress.SetFixedSize(true)
s.wUpdateProgress.SetCloseIntercept(func() {})
s.wUpdateProgress.RequestFocus()
}
func (s *serviceClient) showSettingsUI() {
// Check if update settings are disabled by daemon
features, err := s.getFeatures()
@@ -610,11 +642,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
loginResp, err := conn.Login(ctx, &proto.LoginRequest{
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
Username: &currUser.Username,
})
}
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginReq.Hint = &profileState.Email
}
loginResp, err := conn.Login(ctx, loginReq)
if err != nil {
return nil, fmt.Errorf("login to management: %w", err)
}
@@ -934,6 +975,29 @@ func (s *serviceClient) onTrayReady() {
s.updateExitNodes()
}
})
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if windowAction, ok := event.Metadata["progress_window"]; ok {
log.Debugf("window action: %v", windowAction)
if windowAction == "show" {
log.Debugf("Inside show")
if s.updateContextCancel != nil {
s.updateContextCancel()
s.updateContextCancel = nil
}
subCtx, cancel := context.WithCancel(s.ctx)
go s.eventHandler.runSelfCommand(subCtx, "update", "true")
s.updateContextCancel = cancel
}
if windowAction == "hide" {
log.Debugf("Inside hide")
if s.updateContextCancel != nil {
s.updateContextCancel()
s.updateContextCancel = nil
}
}
}
})
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)

10
go.mod
View File

@@ -56,6 +56,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
github.com/hashicorp/go-version v1.6.0
github.com/jackc/pgx/v5 v5.5.5
github.com/libdns/route53 v1.5.0
github.com/libp2p/go-netroute v0.2.1
github.com/mdlayher/socket v0.5.1
@@ -102,11 +103,12 @@ require (
goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/mod v0.25.0
golang.org/x/mod v0.26.0
golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.28.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.16.0
golang.org/x/term v0.33.0
golang.org/x/time v0.12.0
google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7
@@ -146,7 +148,7 @@ require (
github.com/beorn7/perks v1.0.1 // indirect
github.com/caddyserver/zerossl v0.1.3 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/containerd/containerd v1.7.27 // indirect
github.com/containerd/containerd v1.7.29 // indirect
github.com/containerd/log v0.1.0 // indirect
github.com/containerd/platforms v0.2.1 // indirect
github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -183,7 +185,6 @@ require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
@@ -245,7 +246,6 @@ require (
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect
golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.34.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect

16
go.sum
View File

@@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
s.update = version.NewUpdate("nb/management")
s.update = version.NewUpdateAndStart("nb/management")
s.update.SetDaemonVersion(version.NetbirdVersion())
s.update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())

View File

@@ -1,11 +1,19 @@
package main
import (
"github.com/netbirdio/netbird/management/cmd"
"log"
"net/http"
// nolint:gosec
_ "net/http/pprof"
"os"
"github.com/netbirdio/netbird/management/cmd"
)
func main() {
go func() {
log.Println(http.ListenAndServe("localhost:6060", nil))
}()
if err := cmd.Execute(); err != nil {
os.Exit(1)
}

View File

@@ -53,6 +53,9 @@ const (
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
)
type userLoggedInOnce bool
@@ -109,6 +112,11 @@ type DefaultAccountManager struct {
loginFilter *loginFilter
disableDefaultPolicy bool
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
func isUniqueConstraintError(err error) bool {
@@ -196,6 +204,18 @@ func BuildManager(
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
}()
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
am := &DefaultAccountManager{
Store: store,
geo: geo,
@@ -217,6 +237,10 @@ func BuildManager(
permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
am.startWarmup(ctx)
@@ -340,7 +364,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
oldSettings.DNSDomain != newSettings.DNSDomain ||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
updateAccountPeers = true
}
@@ -376,6 +401,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
@@ -395,6 +421,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
go am.UpdateAccountPeers(ctx, accountID)
}
@@ -477,6 +506,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
}
}
func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
"version": newSettings.AutoUpdateVersion,
})
}
}
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
@@ -1477,6 +1514,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
return err
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
@@ -1641,11 +1682,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
}()
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
@@ -2129,6 +2165,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
}
if updateNetworkMap {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -128,4 +128,5 @@ type Manager interface {
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
}

View File

@@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{
@@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
@@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{
@@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
for drained := false; !drained; {
select {
case <-updMsg:
default:
drained = true
}
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)

View File

@@ -180,6 +180,8 @@ const (
UserApproved Activity = 89
UserRejected Activity = 90
AccountAutoUpdateVersionUpdated Activity = 91
AccountDeleted Activity = 99999
)
@@ -286,8 +288,11 @@ var activityMap = map[Activity]Code{
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
UserApproved: {"User approved", "user.approve"},
UserRejected: {"User rejected", "user.reject"},
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
}
// StringCode returns a string code of the activity

View File

@@ -7,6 +7,7 @@ import (
"path/filepath"
"runtime"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
@@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e
return nil, err
}
if storeEngine == types.SqliteStoreEngine {
sqlDB.SetMaxOpenConns(1)
} else {
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
if err != nil {
conns = runtime.NumCPU()
}
sqlDB.SetMaxOpenConns(conns)
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
if err != nil {
conns = runtime.NumCPU()
}
if storeEngine == types.SqliteStoreEngine {
conns = 1
}
sqlDB.SetMaxOpenConns(conns)
sqlDB.SetMaxIdleConns(conns)
sqlDB.SetConnMaxLifetime(time.Hour)
sqlDB.SetConnMaxIdleTime(3 * time.Minute)
log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
conns, conns, time.Hour, 3*time.Minute)
return db, nil
}

View File

@@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -138,6 +141,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err != nil {
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
@@ -157,11 +165,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
}
newGroup.AccountID = accountID
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
@@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -335,6 +347,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
if oldGroup.Name != newGroup.Name {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"old_name": oldGroup.Name,
"new_name": newGroup.Name,
}
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
})
}
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
@@ -471,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -509,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -537,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -575,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -7,8 +7,10 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -44,6 +46,9 @@ import (
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
defaultSyncLim = 1000
)
// GRPCServer an instance of a Management gRPC API server
@@ -63,6 +68,9 @@ type GRPCServer struct {
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
@@ -96,6 +104,17 @@ func NewServer(
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
}
}
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -110,6 +129,8 @@ func NewServer(
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
syncLim: syncLim,
}, nil
}
@@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
reqStart := time.Now()
ctx := srv.Context()
@@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
s.syncSem.Add(-1)
return err
}
realIP := getRealIP(ctx)
@@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
@@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
if err != nil {
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
s.syncSem.Add(-1)
return status.Errorf(codes.PermissionDenied, "peer is not registered")
}
s.syncSem.Add(-1)
return err
}
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
// nolint:staticcheck
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
start := time.Now()
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
defer func() {
if unlock != nil {
unlock()
}
}()
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
if syncReq.GetMeta() == nil {
@@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart))
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return err
}
log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart))
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart))
s.ephemeralManager.OnPeerConnected(ctx, peer)
log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart))
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart))
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
}
@@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
@@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
@@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
@@ -712,6 +771,9 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled,
AutoUpdate: &proto.AutoUpdateSettings{
Version: settings.AutoUpdateVersion,
},
}
}
@@ -719,9 +781,10 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes),
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache, dnsFwdPort),
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
},
Checks: toProtocolChecks(ctx, checks),
}
@@ -822,10 +885,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)

View File

@@ -0,0 +1,39 @@
package server
import (
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
a := am.holder.GetAccount(account.Id)
if a == nil {
am.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
am.holder.AddAccount(account)
}
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
return am.holder.GetAccount(accountID)
}
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
a := am.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
am.holder.AddAccount(account)
}

View File

@@ -4,9 +4,13 @@ import (
"context"
"fmt"
"net/http"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
@@ -38,7 +42,12 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
const apiPrefix = "/api"
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
@@ -58,11 +67,42 @@ func NewAPIHandler(
settingsManager settings.Manager,
) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
)
corsMiddleware := cors.AllowAll()

View File

@@ -3,12 +3,15 @@ package accounts
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/netip"
"time"
"github.com/gorilla/mux"
goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/settings"
@@ -26,7 +29,9 @@ const (
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
MinNetworkBitsIPv4 = 28
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
MinNetworkBitsIPv6 = 120
MinNetworkBitsIPv6 = 120
disableAutoUpdate = "disabled"
autoUpdateLatestVersion = "latest"
)
// handler is a handler that handles the server.Account HTTP endpoints
@@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
returnSettings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {
returnSettings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
}
}
if req.Settings.JwtGroupsEnabled != nil {
returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
}
if req.Settings.GroupsPropagationEnabled != nil {
returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
}
if req.Settings.JwtGroupsClaimName != nil {
returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
}
if req.Settings.DnsDomain != nil {
returnSettings.DNSDomain = *req.Settings.DnsDomain
}
if req.Settings.LazyConnectionEnabled != nil {
returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
}
if req.Settings.AutoUpdateVersion != nil {
_, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
*req.Settings.AutoUpdateVersion == disableAutoUpdate ||
err == nil {
returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
} else if *req.Settings.AutoUpdateVersion != "" {
return nil, fmt.Errorf("invalid AutoUpdateVersion")
}
}
return returnSettings, nil
}
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
@@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return
}
settings := &types.Settings{
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
}
if req.Settings.Extra != nil {
settings.Extra = &types.ExtraSettings{
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
}
}
if req.Settings.JwtGroupsEnabled != nil {
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
}
if req.Settings.GroupsPropagationEnabled != nil {
settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
}
if req.Settings.JwtGroupsClaimName != nil {
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
}
if req.Settings.JwtAllowGroups != nil {
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
}
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
}
if req.Settings.DnsDomain != nil {
settings.DNSDomain = *req.Settings.DnsDomain
}
if req.Settings.LazyConnectionEnabled != nil {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
settings, err := h.updateAccountRequestSettings(req)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
@@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
DnsDomain: &settings.DNSDomain,
AutoUpdateVersion: &settings.AutoUpdateVersion,
}
if settings.NetworkRange.IsValid() {

View File

@@ -120,6 +120,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: true,
expectedID: accountID,
@@ -142,6 +143,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
},
{
name: "PutAccount OK with autoUpdateVersion",
expectedBody: true,
requestType: http.MethodPut,
requestPath: "/api/accounts/" + accountID,
requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
expectedStatus: http.StatusOK,
expectedSettings: api.AccountSettings{
PeerLoginExpiration: 15552000,
PeerLoginExpirationEnabled: true,
GroupsPropagationEnabled: br(false),
JwtGroupsClaimName: sr(""),
JwtGroupsEnabled: br(false),
JwtAllowGroups: &[]string{},
RegularUsersViewBlocked: false,
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr("latest"),
},
expectedArray: false,
expectedID: accountID,
@@ -164,6 +189,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -186,6 +212,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,
@@ -208,6 +235,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
RoutingPeerDnsResolutionEnabled: br(false),
LazyConnectionEnabled: br(false),
DnsDomain: sr(""),
AutoUpdateVersion: sr(""),
},
expectedArray: false,
expectedID: accountID,

View File

@@ -29,6 +29,7 @@ type AuthMiddleware struct {
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
}
// NewAuthMiddleware instance constructor
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
}
}
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
request, err := m.checkPATFromRequest(r, auth)
if err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
err = status.Errorf(status.Unauthorized, "token invalid")
}
util.WriteError(r.Context(), err, w)
return
}
h.ServeHTTP(w, request)
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {

View File

@@ -27,7 +27,9 @@ const (
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
tokenID2 = "tokenID2"
PAT = "nbp_PAT"
PAT2 = "nbp_PAT2"
JWT = "JWT"
wrongToken = "wrongToken"
)
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
tokenID2: {
ID: tokenID2,
Name: "My second token",
HashedToken: "someHash2",
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
CreatedBy: userID,
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
},
},
},
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
if token == PAT {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
if token == PAT2 {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
}
func mockMarkPATUsed(_ context.Context, token string) error {
if token == tokenID {
if token == tokenID || token == tokenID2 {
return nil
}
return fmt.Errorf("Should never get reached")
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
}
func TestAuthMiddleware_RateLimiting(t *testing.T) {
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
// Configure rate limiter: 10 requests per minute with burst of 5
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 5,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make burst requests - all should succeed
successCount := 0
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 5, successCount, "All burst requests should succeed")
// The 6th request should fail (exceeded burst)
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
})
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
// Configure very low rate limit: 1 request per minute
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request should fail (rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
// Configure strict rate limit
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make multiple requests with Bearer token - all should succeed
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
})
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
// Configure rate limiter
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Use first PAT token
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
// Second request with same token should fail
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
// Use second PAT token - should succeed because it has independent rate limit
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
// Second request with PAT2 should also be rate limited
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
// JWT should still work (not rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
})
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
// Configure rate limiter with short cleanup interval and TTL for testing
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: 100 * time.Millisecond,
LimiterTTL: 200 * time.Millisecond,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request - should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request immediately - should fail (burst exhausted)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
time.Sleep(400 * time.Millisecond)
// After cleanup, the limiter should be removed and recreated with full burst capacity
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
// Verify it's a fresh limiter by checking burst is reset
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
for _, tc := range tt {

View File

@@ -0,0 +1,146 @@
package middleware
import (
"context"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
RequestsPerMinute float64
// Burst defines the maximum number of requests that can be made in a burst
Burst int
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
CleanupInterval time.Duration
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
LimiterTTL time.Duration
}
// DefaultRateLimiterConfig returns a default configuration
func DefaultRateLimiterConfig() *RateLimiterConfig {
return &RateLimiterConfig{
RequestsPerMinute: 100,
Burst: 120,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// APIRateLimiter manages rate limiting for API tokens
type APIRateLimiter struct {
config *RateLimiterConfig
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
if config == nil {
config = DefaultRateLimiterConfig()
}
rl := &APIRateLimiter{
config: config,
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
go rl.cleanupLoop()
return rl
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
limiter := rl.getLimiter(key)
return limiter.Allow()
}
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
// getLimiter retrieves or creates a rate limiter for the given key
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
entry, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
rl.mu.Lock()
entry.lastAccess = time.Now()
rl.mu.Unlock()
return entry.limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
if entry, exists := rl.limiters[key]; exists {
entry.lastAccess = time.Now()
return entry.limiter
}
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
rl.limiters[key] = &limiterEntry{
limiter: limiter,
lastAccess: time.Now(),
}
return limiter
}
// cleanupLoop periodically removes old limiters that haven't been used recently
func (rl *APIRateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
rl.cleanup()
case <-rl.stopChan:
return
}
}
}
// cleanup removes limiters that haven't been used within the TTL period
func (rl *APIRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, entry := range rl.limiters {
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
delete(rl.limiters, key)
}
}
}
// Stop stops the cleanup goroutine
func (rl *APIRateLimiter) Stop() {
close(rl.stopChan)
}
// Reset removes the rate limiter for a specific key
func (rl *APIRateLimiter) Reset(key string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.limiters, key)
}

View File

@@ -125,9 +125,10 @@ type MockAccountManager struct {
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
}
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
}
return true
}
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
if am.RecalculateNetworkMapCacheFunc != nil {
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
}
return nil
}

View File

@@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -0,0 +1,80 @@
package server
import (
"context"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
am.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (am *DefaultAccountManager) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := am.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
}
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := am.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
am.updateAccountInHolder(account)
}
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if am.experimentalNetworkMap(accountId) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
am.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
_, ok := am.expNewNetworkMapAIDs[accountId]
return am.expNewNetworkMap || ok
}

View File

@@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
}()
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
@@ -145,6 +140,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
am.BufferUpdateAccountPeers(ctx, accountID)
@@ -321,6 +319,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
if peerLabelChanged || requiresPeerUpdates {
am.UpdateAccountPeers(ctx, accountID)
} else if sshChanged {
@@ -381,6 +383,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
}
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -417,7 +431,13 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(peer.AccountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -690,6 +710,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
@@ -708,11 +739,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
}()
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
@@ -776,6 +802,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -831,6 +860,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -900,8 +930,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
startBuffer := time.Now()
am.BufferUpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer))
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
@@ -997,11 +1034,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1014,9 +1046,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, emptyMap, nil, nil
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
@@ -1024,10 +1064,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}
startPosture := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
@@ -1037,7 +1079,13 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1167,11 +1215,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
}
globalStart := time.Now()
@@ -1204,6 +1259,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if am.experimentalNetworkMap(accountID) {
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -1241,7 +1300,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
@@ -1257,7 +1322,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
@@ -1351,7 +1416,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics())
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1368,7 +1439,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1575,12 +1646,12 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
PeerConfig: toPeerConfig(peer, network, dnsDomain, settings),
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
NetworkMap: &types.NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {

View File

@@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) {
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t)
}
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t)
}
func testUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
@@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
})
}
@@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}

View File

@@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
PolicyID: "RuleDefault",
},
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
@@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.21.56",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",

View File

@@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -5,6 +5,9 @@ package settings
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
}
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,951 @@
package store
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sort"
"sync"
"testing"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
)
func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
var account types.Account
result := s.db.Model(&account).
Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
Preload(clause.Associations).
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
var groupPeers []types.GroupPeer
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
Find(&groupPeers)
for _, groupPeer := range groupPeers {
if group, ok := account.Groups[groupPeer.GroupID]; ok {
group.Peers = append(group.Peers, groupPeer.PeerID)
} else {
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
}
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
return &account, nil
}
func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
Preload("Policies.Rules").
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
Preload("Networks").
Preload("NetworkRouters").
Preload("NetworkResources").
Preload("Onboarding").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
if key.AutoGroups == nil {
key.AutoGroups = []string{}
}
account.SetupKeys[key.Key] = &key
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = &peer
}
account.PeersG = nil
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATs))
for _, pat := range user.PATsG {
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
}
account.Users[user.Id] = &user
user.PATsG = nil
}
account.UsersG = nil
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.Peers = make([]string, len(group.GroupPeers))
for i, gp := range group.GroupPeers {
group.Peers[i] = gp.PeerID
}
if group.Resources == nil {
group.Resources = []types.Resource{}
}
account.Groups[group.ID] = group
}
account.GroupsG = nil
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = &route
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
ns.AccountID = ""
if ns.NameServers == nil {
ns.NameServers = []nbdns.NameServer{}
}
if ns.Groups == nil {
ns.Groups = []string{}
}
if ns.Domains == nil {
ns.Domains = []string{}
}
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
return &account, nil
}
func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("unable to parse database config: %w", err)
}
config.MaxConns = 12
config.MinConns = 2
config.MaxConnLifetime = time.Hour
config.HealthCheckPeriod = time.Minute
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("unable to ping database: %w", err)
}
return pool, nil
}
func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
if err != nil {
b.Fatalf("failed to create test container: %v", err)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
b.Fatalf("failed to connect database: %v", err)
}
pool, err := connectDBforTest(context.Background(), dsn)
if err != nil {
b.Fatalf("failed to connect database: %v", err)
}
models := []interface{}{
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{},
}
for i := len(models) - 1; i >= 0; i-- {
err := db.Migrator().DropTable(models[i])
if err != nil {
b.Fatalf("failed to drop table: %v", err)
}
}
err = db.AutoMigrate(models...)
if err != nil {
b.Fatalf("failed to migrate database: %v", err)
}
store := &SqlStore{
db: db,
pool: pool,
}
const (
accountID = "benchmark-account-id"
numUsers = 20
numPatsPerUser = 3
numSetupKeys = 25
numPeers = 200
numGroups = 30
numPolicies = 50
numRulesPerPolicy = 10
numRoutes = 40
numNSGroups = 10
numPostureChecks = 15
numNetworks = 5
numNetworkRouters = 5
numNetworkResources = 10
)
_, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
acc := types.Account{
Id: accountID,
CreatedBy: "benchmark-user",
CreatedAt: time.Now(),
Domain: "benchmark.com",
IsDomainPrimaryAccount: true,
Network: &types.Network{
Identifier: "benchmark-net",
Net: *ipNet,
Serial: 1,
},
DNSSettings: types.DNSSettings{
DisabledManagementGroups: []string{"group-disabled-1"},
},
Settings: &types.Settings{},
}
if err := db.Create(&acc).Error; err != nil {
b.Fatalf("create account: %v", err)
}
var setupKeys []types.SetupKey
for i := 0; i < numSetupKeys; i++ {
setupKeys = append(setupKeys, types.SetupKey{
Id: fmt.Sprintf("keyid-%d", i),
AccountID: accountID,
Key: fmt.Sprintf("key-%d", i),
Name: fmt.Sprintf("Benchmark Key %d", i),
ExpiresAt: &time.Time{},
})
}
if err := db.Create(&setupKeys).Error; err != nil {
b.Fatalf("create setup keys: %v", err)
}
var peers []nbpeer.Peer
for i := 0; i < numPeers; i++ {
peers = append(peers, nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
AccountID: accountID,
Key: fmt.Sprintf("peerkey-%d", i),
IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
Name: fmt.Sprintf("peer-name-%d", i),
Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
})
}
if err := db.Create(&peers).Error; err != nil {
b.Fatalf("create peers: %v", err)
}
for i := 0; i < numUsers; i++ {
userID := fmt.Sprintf("user-%d", i)
user := types.User{Id: userID, AccountID: accountID}
if err := db.Create(&user).Error; err != nil {
b.Fatalf("create user %s: %v", userID, err)
}
var pats []types.PersonalAccessToken
for j := 0; j < numPatsPerUser; j++ {
pats = append(pats, types.PersonalAccessToken{
ID: fmt.Sprintf("pat-%d-%d", i, j),
UserID: userID,
Name: fmt.Sprintf("PAT %d for User %d", j, i),
})
}
if err := db.Create(&pats).Error; err != nil {
b.Fatalf("create pats for user %s: %v", userID, err)
}
}
var groups []*types.Group
for i := 0; i < numGroups; i++ {
groups = append(groups, &types.Group{
ID: fmt.Sprintf("group-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
})
}
if err := db.Create(&groups).Error; err != nil {
b.Fatalf("create groups: %v", err)
}
for i := 0; i < numPolicies; i++ {
policyID := fmt.Sprintf("policy-%d", i)
policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
if err := db.Create(&policy).Error; err != nil {
b.Fatalf("create policy %s: %v", policyID, err)
}
var rules []*types.PolicyRule
for j := 0; j < numRulesPerPolicy; j++ {
rules = append(rules, &types.PolicyRule{
ID: fmt.Sprintf("rule-%d-%d", i, j),
PolicyID: policyID,
Name: fmt.Sprintf("Rule %d for Policy %d", j, i),
Enabled: true,
Protocol: "all",
})
}
if err := db.Create(&rules).Error; err != nil {
b.Fatalf("create rules for policy %s: %v", policyID, err)
}
}
var routes []route.Route
for i := 0; i < numRoutes; i++ {
routes = append(routes, route.Route{
ID: route.ID(fmt.Sprintf("route-%d", i)),
AccountID: accountID,
Description: fmt.Sprintf("Route %d", i),
Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
Enabled: true,
})
}
if err := db.Create(&routes).Error; err != nil {
b.Fatalf("create routes: %v", err)
}
var nsGroups []nbdns.NameServerGroup
for i := 0; i < numNSGroups; i++ {
nsGroups = append(nsGroups, nbdns.NameServerGroup{
ID: fmt.Sprintf("nsg-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("NS Group %d", i),
Description: "Benchmark NS Group",
Enabled: true,
})
}
if err := db.Create(&nsGroups).Error; err != nil {
b.Fatalf("create nsgroups: %v", err)
}
var postureChecks []*posture.Checks
for i := 0; i < numPostureChecks; i++ {
postureChecks = append(postureChecks, &posture.Checks{
ID: fmt.Sprintf("pc-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Posture Check %d", i),
})
}
if err := db.Create(&postureChecks).Error; err != nil {
b.Fatalf("create posture checks: %v", err)
}
var networks []*networkTypes.Network
for i := 0; i < numNetworks; i++ {
networks = append(networks, &networkTypes.Network{
ID: fmt.Sprintf("nettype-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Network Type %d", i),
})
}
if err := db.Create(&networks).Error; err != nil {
b.Fatalf("create networks: %v", err)
}
var networkRouters []*routerTypes.NetworkRouter
for i := 0; i < numNetworkRouters; i++ {
networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
ID: fmt.Sprintf("router-%d", i),
AccountID: accountID,
NetworkID: networks[i%numNetworks].ID,
Peer: peers[i%numPeers].ID,
})
}
if err := db.Create(&networkRouters).Error; err != nil {
b.Fatalf("create network routers: %v", err)
}
var networkResources []*resourceTypes.NetworkResource
for i := 0; i < numNetworkResources; i++ {
networkResources = append(networkResources, &resourceTypes.NetworkResource{
ID: fmt.Sprintf("resource-%d", i),
AccountID: accountID,
NetworkID: networks[i%numNetworks].ID,
Name: fmt.Sprintf("Resource %d", i),
})
}
if err := db.Create(&networkResources).Error; err != nil {
b.Fatalf("create network resources: %v", err)
}
onboarding := types.AccountOnboarding{
AccountID: accountID,
OnboardingFlowPending: true,
}
if err := db.Create(&onboarding).Error; err != nil {
b.Fatalf("create onboarding: %v", err)
}
return store, cleanup, accountID
}
func BenchmarkGetAccount(b *testing.B) {
store, cleanup, accountID := setupBenchmarkDB(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
b.Run("old", func(b *testing.B) {
for range b.N {
_, err := store.GetAccountSlow(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountSlow failed: %v", err)
}
}
})
b.Run("gorm opt", func(b *testing.B) {
for range b.N {
_, err := store.GetAccountGormOpt(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountFast failed: %v", err)
}
}
})
b.Run("raw", func(b *testing.B) {
for range b.N {
_, err := store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountPureSQL failed: %v", err)
}
}
})
store.pool.Close()
}
func TestAccountEquivalence(t *testing.T) {
store, cleanup, accountID := setupBenchmarkDB(t)
defer cleanup()
ctx := context.Background()
type getAccountFunc func(context.Context, string) (*types.Account, error)
tests := []struct {
name string
expectedF getAccountFunc
actualF getAccountFunc
}{
{"old vs new", store.GetAccountSlow, store.GetAccountGormOpt},
{"old vs raw", store.GetAccountSlow, store.GetAccount},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expected, errOld := tt.expectedF(ctx, accountID)
assert.NoError(t, errOld, "expected function should not return an error")
assert.NotNil(t, expected, "expected should not be nil")
actual, errNew := tt.actualF(ctx, accountID)
assert.NoError(t, errNew, "actual function should not return an error")
assert.NotNil(t, actual, "actual should not be nil")
testAccountEquivalence(t, expected, actual)
})
}
expected, errOld := store.GetAccountSlow(ctx, accountID)
assert.NoError(t, errOld, "GetAccountSlow should not return an error")
assert.NotNil(t, expected, "expected should not be nil")
actual, errNew := store.GetAccount(ctx, accountID)
assert.NoError(t, errNew, "GetAccount (new) should not return an error")
assert.NotNil(t, actual, "actual should not be nil")
}
func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")
assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
for key, oldVal := range expected.SetupKeys {
newVal, ok := actual.SetupKeys[key]
assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
}
assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
for key, oldVal := range expected.Peers {
newVal, ok := actual.Peers[key]
assert.True(t, ok, "Peer with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
}
assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
for key, oldUser := range expected.Users {
newUser, ok := actual.Users[key]
assert.True(t, ok, "User with ID '%s' should exist in new account", key)
assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
for patKey, oldPAT := range oldUser.PATs {
newPAT, patOk := newUser.PATs[patKey]
assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key)
assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
}
oldUser.PATs = nil
newUser.PATs = nil
assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key)
}
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
for key, oldVal := range expected.Groups {
newVal, ok := actual.Groups[key]
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
sort.Strings(oldVal.Peers)
sort.Strings(newVal.Peers)
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
}
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
for key, oldVal := range expected.Routes {
newVal, ok := actual.Routes[key]
assert.True(t, ok, "Route with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
}
assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
for key, oldVal := range expected.NameServerGroups {
newVal, ok := actual.NameServerGroups[key]
assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
}
assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements")
sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
for i := range expected.Policies {
sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
}
assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements")
sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
for i := range expected.PostureChecks {
assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
}
assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements")
sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
for i := range expected.Networks {
assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
}
assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements")
sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
for i := range expected.NetworkRouters {
assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
}
assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements")
sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
for i := range expected.NetworkResources {
assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
}
}
func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
account, err := s.getAccount(ctx, accountID)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
errChan := make(chan error, 12)
wg.Add(1)
go func() {
defer wg.Done()
keys, err := s.getSetupKeys(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.SetupKeysG = keys
}()
wg.Add(1)
go func() {
defer wg.Done()
peers, err := s.getPeers(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PeersG = peers
}()
wg.Add(1)
go func() {
defer wg.Done()
users, err := s.getUsers(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.UsersG = users
}()
wg.Add(1)
go func() {
defer wg.Done()
groups, err := s.getGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.GroupsG = groups
}()
wg.Add(1)
go func() {
defer wg.Done()
policies, err := s.getPolicies(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Policies = policies
}()
wg.Add(1)
go func() {
defer wg.Done()
routes, err := s.getRoutes(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.RoutesG = routes
}()
wg.Add(1)
go func() {
defer wg.Done()
nsgs, err := s.getNameServerGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NameServerGroupsG = nsgs
}()
wg.Add(1)
go func() {
defer wg.Done()
checks, err := s.getPostureChecks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PostureChecks = checks
}()
wg.Add(1)
go func() {
defer wg.Done()
networks, err := s.getNetworks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Networks = networks
}()
wg.Add(1)
go func() {
defer wg.Done()
routers, err := s.getNetworkRouters(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkRouters = routers
}()
wg.Add(1)
go func() {
defer wg.Done()
resources, err := s.getNetworkResources(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkResources = resources
}()
wg.Add(1)
go func() {
defer wg.Done()
err := s.getAccountOnboarding(ctx, accountID, account)
if err != nil {
errChan <- err
return
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
var userIDs []string
for _, u := range account.UsersG {
userIDs = append(userIDs, u.Id)
}
var policyIDs []string
for _, p := range account.Policies {
policyIDs = append(policyIDs, p.ID)
}
var groupIDs []string
for _, g := range account.GroupsG {
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
var pats []types.PersonalAccessToken
go func() {
defer wg.Done()
var err error
pats, err = s.getPersonalAccessTokens(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
var rules []*types.PolicyRule
go func() {
defer wg.Done()
var err error
rules, err = s.getPolicyRules(ctx, policyIDs)
if err != nil {
errChan <- err
}
}()
var groupPeers []types.GroupPeer
go func() {
defer wg.Done()
var err error
groupPeers, err = s.getGroupPeers(ctx, groupIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
patsByUserID := make(map[string][]*types.PersonalAccessToken)
for i := range pats {
pat := &pats[i]
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
pat.UserID = ""
}
rulesByPolicyID := make(map[string][]*types.PolicyRule)
for _, rule := range rules {
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
}
peersByGroupID := make(map[string][]string)
for _, gp := range groupPeers {
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
account.SetupKeys[key.Key] = key
}
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for i := range account.PeersG {
peer := &account.PeersG[i]
account.Peers[peer.ID] = peer
}
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {
pat := userPats[j]
user.PATs[pat.ID] = pat
}
}
account.Users[user.Id] = user
}
for i := range account.Policies {
policy := account.Policies[i]
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
policy.Rules = policyRules
}
}
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for i := range account.GroupsG {
group := account.GroupsG[i]
if peerIDs, ok := peersByGroupID[group.ID]; ok {
group.Peers = peerIDs
}
account.Groups[group.ID] = group
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for i := range account.RoutesG {
route := &account.RoutesG[i]
account.Routes[route.ID] = route
}
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for i := range account.NameServerGroupsG {
nsg := &account.NameServerGroupsG[i]
nsg.AccountID = ""
account.NameServerGroups[nsg.ID] = nsg
}
account.SetupKeysG = nil
account.PeersG = nil
account.UsersG = nil
account.GroupsG = nil
account.RoutesG = nil
account.NameServerGroupsG = nil
return account, nil
}

View File

@@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
closeConnection := func() {
cleanup()
store.Close(ctx)
if store.pool != nil {
store.pool.Close()
}
}
return store, closeConnection, nil
@@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}
dsn, cleanup, err := createRandomDB(dsn, db, kind)
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
if err != nil {
return nil, nil, err
}
@@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err)
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
dsn, cleanup, err := createRandomDB(dsn, db, kind)
sqlDB.Close()
if err != nil {
return nil, nil, err
}
@@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return store, cleanup, nil
}
func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) {
var db *gorm.DB
var err error
for i := range maxRetries {
switch engine {
case types.PostgresStoreEngine:
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case types.MysqlStoreEngine:
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
}
if err == nil {
return db, nil
}
if i < maxRetries-1 {
waitTime := time.Duration(100*(i+1)) * time.Millisecond
time.Sleep(waitTime)
}
}
return nil, err
}
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
@@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(
return "", nil, fmt.Errorf("failed to create database: %v", err)
}
var err error
originalDSN := dsn
cleanup := func() {
var dropDB *gorm.DB
var err error
switch engine {
case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{
SkipDefaultTransaction: true,
PrepareStmt: false,
})
if err != nil {
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
return
}
defer func() {
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.Close()
}
}()
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(0)
sqlDB.SetConnMaxLifetime(time.Second)
}
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error
case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{
SkipDefaultTransaction: true,
PrepareStmt: false,
})
if err != nil {
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
return
}
defer func() {
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.Close()
}
}()
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(0)
sqlDB.SetConnMaxLifetime(time.Second)
}
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error
}
if err != nil {
log.Errorf("failed to drop database %s: %v", dbName, err)
panic(err)
}
sqlDB, _ := db.DB()
_ = sqlDB.Close()
}
return replaceDBName(dsn, dbName), cleanup, nil

View File

@@ -8,6 +8,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
@@ -87,6 +88,13 @@ type Account struct {
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
}
func (a *Account) InitOnce() {
a.nmapInitOnce = &sync.Once{}
}
// this class is used by gorm only
@@ -257,7 +265,6 @@ func (a *Account) GetPeerNetworkMap(
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -890,6 +897,8 @@ func (a *Account) Copy() *Account {
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
}
}
@@ -1049,14 +1058,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &Group{}
}
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
@@ -1075,10 +1077,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = "0.0.0.0"
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {

View File

@@ -0,0 +1,43 @@
package types
import (
"context"
"sync"
)
type Holder struct {
mu sync.RWMutex
accounts map[string]*Account
}
func NewHolder() *Holder {
return &Holder{
accounts: make(map[string]*Account),
}
}
func (h *Holder) GetAccount(id string) *Account {
h.mu.RLock()
defer h.mu.RUnlock()
return h.accounts[id]
}
func (h *Holder) AddAccount(account *Account) {
h.mu.Lock()
defer h.mu.Unlock()
h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(context.Background(), id)
if err != nil {
return nil, err
}
h.accounts[id] = account
return account, nil
}

View File

@@ -0,0 +1,58 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
if a.NetworkMapCache != nil {
return
}
a.nmapInitOnce.Do(func() {
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
})
}
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
}
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerDeleted(peerId)
}
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.UpdatePeer(peer)
}
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}

Some files were not shown because too many files have changed in this diff Show More