mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
27 Commits
v0.59.9
...
vk/debug/n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8a9af2482 | ||
|
|
4bcda3e2ba | ||
|
|
c28275611b | ||
|
|
56f169eede | ||
|
|
b780f1c09d | ||
|
|
07cf9d5895 | ||
|
|
7df49e249d | ||
|
|
dbfc8a52c9 | ||
|
|
98ddac07bf | ||
|
|
48475ddc05 | ||
|
|
6aa4ba7af4 | ||
|
|
2e16c9914a | ||
|
|
5c29d395b2 | ||
|
|
229e0038ee | ||
|
|
75327d9519 | ||
|
|
c92e6c1b5f | ||
|
|
641eb5140b | ||
|
|
45c25dca84 | ||
|
|
679c58ce47 | ||
|
|
719283c792 | ||
|
|
a2313a5ba4 | ||
|
|
8c108ccad3 | ||
|
|
86eff0d750 | ||
|
|
43c9a51913 | ||
|
|
c530db1455 | ||
|
|
1ee575befe | ||
|
|
d3a34adcc9 |
@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
||||
}
|
||||
|
||||
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||
loginRequest.OptionalPreSharedKey = &preSharedKey
|
||||
}
|
||||
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error {
|
||||
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
|
||||
needsLogin := false
|
||||
|
||||
err := WithBackOff(func() error {
|
||||
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
|
||||
jwtToken := ""
|
||||
if setupKey == "" && needsLogin {
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config)
|
||||
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interactive sso login failed: %v", err)
|
||||
}
|
||||
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
||||
return nil
|
||||
}
|
||||
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
|
||||
hint := ""
|
||||
pm := profilemanager.NewProfileManager()
|
||||
profileState, err := pm.GetProfileState(profileName)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
hint = profileState.Email
|
||||
}
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/kardianos/service"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
@@ -81,6 +83,10 @@ func configurePlatformSpecificSettings(svcConfig *service.Config) error {
|
||||
svcConfig.Option["LogDirectory"] = dir
|
||||
}
|
||||
}
|
||||
|
||||
if err := configureSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to configure systemd-networkd: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
@@ -160,6 +166,12 @@ var uninstallCmd = &cobra.Command{
|
||||
return fmt.Errorf("uninstall service: %w", err)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
if err := cleanupSystemdNetworkd(); err != nil {
|
||||
log.Warnf("failed to cleanup systemd-networkd configuration: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Println("NetBird service has been uninstalled")
|
||||
return nil
|
||||
},
|
||||
@@ -245,3 +257,50 @@ func isServiceRunning() (bool, error) {
|
||||
|
||||
return status == service.StatusRunning, nil
|
||||
}
|
||||
|
||||
const (
|
||||
networkdConf = "/etc/systemd/networkd.conf"
|
||||
networkdConfDir = "/etc/systemd/networkd.conf.d"
|
||||
networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
|
||||
networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
|
||||
# routes and policy rules managed by NetBird.
|
||||
|
||||
[Network]
|
||||
ManageForeignRoutes=no
|
||||
ManageForeignRoutingPolicyRules=no
|
||||
`
|
||||
)
|
||||
|
||||
// configureSystemdNetworkd creates a drop-in configuration file to prevent
|
||||
// systemd-networkd from removing NetBird's routes and policy rules.
|
||||
func configureSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
|
||||
log.Debug("systemd-networkd not in use, skipping configuration")
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
|
||||
return fmt.Errorf("create networkd.conf.d directory: %w", err)
|
||||
}
|
||||
|
||||
// nolint:gosec // standard networkd permissions
|
||||
if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
|
||||
return fmt.Errorf("write networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSystemdNetworkd removes the NetBird systemd-networkd configuration file.
|
||||
func cleanupSystemdNetworkd() error {
|
||||
if _, err := os.Stat(networkdConfFile); os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := os.Remove(networkdConfFile); err != nil {
|
||||
return fmt.Errorf("remove networkd configuration: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
|
||||
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginRequest.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
var loginErr error
|
||||
var loginResp *proto.LoginResponse
|
||||
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
)
|
||||
|
||||
// NewFirewall creates a firewall manager instance
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
if !iface.IsUserspaceBind() {
|
||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||
}
|
||||
|
||||
// use userspace packet filtering firewall
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -34,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
||||
// FWType is the type for the firewall type
|
||||
type FWType int
|
||||
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool, mtu uint16) (firewall.Manager, error) {
|
||||
// on the linux system we try to user nftables or iptables
|
||||
// in any case, because we need to allow netbird interface traffic
|
||||
// so we use AllowNetbird traffic from these firewall managers
|
||||
// for the userspace packet filtering firewall
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes, mtu)
|
||||
|
||||
if !iface.IsUserspaceBind() {
|
||||
return fm, err
|
||||
@@ -48,11 +48,11 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
|
||||
if err != nil {
|
||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||
}
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface)
|
||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool, mtu uint16) (firewall.Manager, error) {
|
||||
fm, err := createFW(iface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create firewall: %s", err)
|
||||
}
|
||||
@@ -64,26 +64,26 @@ func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager,
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
||||
func createFW(iface IFaceMapper, mtu uint16) (firewall.Manager, error) {
|
||||
switch check() {
|
||||
case IPTABLES:
|
||||
log.Info("creating an iptables firewall manager")
|
||||
return nbiptables.Create(iface)
|
||||
return nbiptables.Create(iface, mtu)
|
||||
case NFTABLES:
|
||||
log.Info("creating an nftables firewall manager")
|
||||
return nbnftables.Create(iface)
|
||||
return nbnftables.Create(iface, mtu)
|
||||
default:
|
||||
log.Info("no firewall manager found, trying to use userspace packet filtering firewall")
|
||||
return nil, errors.New("no firewall manager found")
|
||||
}
|
||||
}
|
||||
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (firewall.Manager, error) {
|
||||
var errUsp error
|
||||
if fm != nil {
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger, mtu)
|
||||
} else {
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
if errUsp != nil {
|
||||
|
||||
@@ -36,7 +36,7 @@ type iFaceMapper interface {
|
||||
}
|
||||
|
||||
// Create iptables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init iptables: %w", err)
|
||||
@@ -47,7 +47,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
ipv4Client: iptablesClient,
|
||||
}
|
||||
|
||||
m.router, err = newRouter(iptablesClient, wgIface)
|
||||
m.router, err = newRouter(iptablesClient, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -66,6 +66,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}
|
||||
stateManager.RegisterState(state)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -53,7 +54,7 @@ func TestIptablesManager(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -114,7 +115,7 @@ func TestIptablesManagerDenyRules(t *testing.T) {
|
||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -198,7 +199,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
||||
}
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -264,7 +265,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -30,17 +30,20 @@ const (
|
||||
|
||||
chainPOSTROUTING = "POSTROUTING"
|
||||
chainPREROUTING = "PREROUTING"
|
||||
chainFORWARD = "FORWARD"
|
||||
chainRTNAT = "NETBIRD-RT-NAT"
|
||||
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||
chainRTPRE = "NETBIRD-RT-PRE"
|
||||
chainRTRDR = "NETBIRD-RT-RDR"
|
||||
chainRTMSSCLAMP = "NETBIRD-RT-MSSCLAMP"
|
||||
routingFinalForwardJump = "ACCEPT"
|
||||
routingFinalNatJump = "MASQUERADE"
|
||||
|
||||
jumpManglePre = "jump-mangle-pre"
|
||||
jumpNatPre = "jump-nat-pre"
|
||||
jumpNatPost = "jump-nat-post"
|
||||
jumpMSSClamp = "jump-mss-clamp"
|
||||
markManglePre = "mark-mangle-pre"
|
||||
markManglePost = "mark-mangle-post"
|
||||
matchSet = "--match-set"
|
||||
@@ -48,6 +51,9 @@ const (
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
fwdSuffix = "_fwd"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
type ruleInfo struct {
|
||||
@@ -77,16 +83,18 @@ type router struct {
|
||||
ipsetCounter *ipsetCounter
|
||||
wgIface iFaceMapper
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
|
||||
stateManager *statemanager.Manager
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
}
|
||||
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
iptablesClient: iptablesClient,
|
||||
rules: make(map[string][]string),
|
||||
wgIface: wgIface,
|
||||
mtu: mtu,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
}
|
||||
|
||||
@@ -392,6 +400,7 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||
if err != nil {
|
||||
@@ -416,6 +425,7 @@ func (r *router) createContainers() error {
|
||||
{chainRTPRE, tableMangle},
|
||||
{chainRTNAT, tableNat},
|
||||
{chainRTRDR, tableNat},
|
||||
{chainRTMSSCLAMP, tableMangle},
|
||||
} {
|
||||
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||
@@ -438,6 +448,10 @@ func (r *router) createContainers() error {
|
||||
return fmt.Errorf("add jump rules: %w", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -518,6 +532,35 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
// Add jump rule from FORWARD chain in mangle table to our custom chain
|
||||
jumpRule := []string{
|
||||
"-j", chainRTMSSCLAMP,
|
||||
}
|
||||
if err := r.iptablesClient.Insert(tableMangle, chainFORWARD, 1, jumpRule...); err != nil {
|
||||
return fmt.Errorf("add jump to MSS clamp chain: %w", err)
|
||||
}
|
||||
r.rules[jumpMSSClamp] = jumpRule
|
||||
|
||||
ruleOut := []string{
|
||||
"-o", r.wgIface.Name(),
|
||||
"-p", "tcp",
|
||||
"--tcp-flags", "SYN,RST", "SYN",
|
||||
"-j", "TCPMSS",
|
||||
"--set-mss", fmt.Sprintf("%d", mss),
|
||||
}
|
||||
if err := r.iptablesClient.Append(tableMangle, chainRTMSSCLAMP, ruleOut...); err != nil {
|
||||
return fmt.Errorf("add outbound MSS clamp rule: %w", err)
|
||||
}
|
||||
r.rules["mss-clamp-out"] = ruleOut
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) insertEstablishedRule(chain string) error {
|
||||
establishedRule := getConntrackEstablished()
|
||||
|
||||
@@ -558,7 +601,7 @@ func (r *router) addJumpRules() error {
|
||||
}
|
||||
|
||||
func (r *router) cleanJumpRules() error {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre, jumpMSSClamp} {
|
||||
if rule, exists := r.rules[ruleKey]; exists {
|
||||
var table, chain string
|
||||
switch ruleKey {
|
||||
@@ -571,6 +614,9 @@ func (r *router) cleanJumpRules() error {
|
||||
case jumpNatPre:
|
||||
table = tableNat
|
||||
chain = chainPREROUTING
|
||||
case jumpMSSClamp:
|
||||
table = tableMangle
|
||||
chain = chainFORWARD
|
||||
default:
|
||||
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
@@ -30,7 +31,7 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "should return a valid iptables manager")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -38,7 +39,6 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
assert.NoError(t, manager.Reset(), "shouldn't return error")
|
||||
}()
|
||||
|
||||
// Now 5 rules:
|
||||
// 1. established rule forward in
|
||||
// 2. estbalished rule forward out
|
||||
// 3. jump rule to POST nat chain
|
||||
@@ -48,7 +48,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
||||
// 7. static return masquerade rule
|
||||
// 8. mangle prerouting mark rule
|
||||
// 9. mangle postrouting mark rule
|
||||
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||
// 10. jump rule to MSS clamping chain
|
||||
// 11. MSS clamping rule for outbound traffic
|
||||
require.Len(t, manager.rules, 11, "should have created rules map")
|
||||
|
||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||
@@ -82,7 +84,7 @@ func TestIptablesManager_AddNatRule(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "failed to init iptables client")
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
|
||||
@@ -155,7 +157,7 @@ func TestIptablesManager_RemoveNatRule(t *testing.T) {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
iptablesClient, _ := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
|
||||
manager, err := newRouter(iptablesClient, ifaceMock)
|
||||
manager, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "shouldn't return error")
|
||||
require.NoError(t, manager.init(nil))
|
||||
defer func() {
|
||||
@@ -217,7 +219,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
iptablesClient, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
||||
require.NoError(t, err, "Failed to create iptables client")
|
||||
|
||||
r, err := newRouter(iptablesClient, ifaceMock)
|
||||
r, err := newRouter(iptablesClient, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router manager")
|
||||
require.NoError(t, r.init(nil))
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -11,6 +12,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -42,7 +44,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
ipt, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
ipt, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create iptables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -100,6 +100,9 @@ type Manager interface {
|
||||
//
|
||||
// If comment argument is empty firewall manager should set
|
||||
// rule ID as comment for the rule
|
||||
//
|
||||
// Note: Callers should call Flush() after adding rules to ensure
|
||||
// they are applied to the kernel and rule handles are refreshed.
|
||||
AddPeerFiltering(
|
||||
id []byte,
|
||||
ip net.IP,
|
||||
|
||||
@@ -29,8 +29,6 @@ const (
|
||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||
|
||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||
)
|
||||
|
||||
const flushError = "flush: %w"
|
||||
@@ -195,25 +193,6 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
||||
// createDefaultAllowRules creates default allow rules for the input and output chains
|
||||
func (m *AclManager) createDefaultAllowRules() error {
|
||||
expIn := []expr.Any{
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseNetworkHeader,
|
||||
Offset: 12,
|
||||
Len: 4,
|
||||
},
|
||||
// mask
|
||||
&expr.Bitwise{
|
||||
SourceRegister: 1,
|
||||
DestRegister: 1,
|
||||
Len: 4,
|
||||
Mask: []byte{0, 0, 0, 0},
|
||||
Xor: []byte{0, 0, 0, 0},
|
||||
},
|
||||
// net address
|
||||
&expr.Cmp{
|
||||
Register: 1,
|
||||
Data: []byte{0, 0, 0, 0},
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
@@ -258,7 +237,7 @@ func (m *AclManager) addIOFiltering(
|
||||
action firewall.Action,
|
||||
ipset *nftables.Set,
|
||||
) (*Rule, error) {
|
||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||
ruleId := generatePeerRuleId(ip, proto, sPort, dPort, action, ipset)
|
||||
if r, ok := m.rules[ruleId]; ok {
|
||||
return &Rule{
|
||||
nftRule: r.nftRule,
|
||||
@@ -357,11 +336,12 @@ func (m *AclManager) addIOFiltering(
|
||||
}
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return nil, fmt.Errorf(flushError, err)
|
||||
return nil, fmt.Errorf("flush input rule %s: %v", ruleId, err)
|
||||
}
|
||||
|
||||
ruleStruct := &Rule{
|
||||
nftRule: nftRule,
|
||||
nftRule: nftRule,
|
||||
// best effort mangle rule
|
||||
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||
nftSet: ipset,
|
||||
ruleID: ruleId,
|
||||
@@ -420,12 +400,19 @@ func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byt
|
||||
},
|
||||
)
|
||||
|
||||
return m.rConn.AddRule(&nftables.Rule{
|
||||
nfRule := m.rConn.AddRule(&nftables.Rule{
|
||||
Table: m.workTable,
|
||||
Chain: m.chainPrerouting,
|
||||
Exprs: preroutingExprs,
|
||||
UserData: userData,
|
||||
})
|
||||
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
log.Errorf("failed to flush mangle rule %s: %v", string(userData), err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nfRule
|
||||
}
|
||||
|
||||
func (m *AclManager) createDefaultChains() (err error) {
|
||||
@@ -697,8 +684,8 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":"
|
||||
func generatePeerRuleId(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||
rulesetID := ":" + string(proto) + ":"
|
||||
if sPort != nil {
|
||||
rulesetID += sPort.String()
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package nftables
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/google/nftables"
|
||||
@@ -19,13 +19,22 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// tableNameNetbird is the name of the table that is used for filtering by the Netbird client
|
||||
// tableNameNetbird is the default name of the table that is used for filtering by the Netbird client
|
||||
tableNameNetbird = "netbird"
|
||||
// envTableName is the environment variable to override the table name
|
||||
envTableName = "NB_NFTABLES_TABLE"
|
||||
|
||||
tableNameFilter = "filter"
|
||||
chainNameInput = "INPUT"
|
||||
)
|
||||
|
||||
func getTableName() string {
|
||||
if name := os.Getenv(envTableName); name != "" {
|
||||
return name
|
||||
}
|
||||
return tableNameNetbird
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
type iFaceMapper interface {
|
||||
Name() string
|
||||
@@ -44,16 +53,16 @@ type Manager struct {
|
||||
}
|
||||
|
||||
// Create nftables firewall manager
|
||||
func Create(wgIface iFaceMapper) (*Manager, error) {
|
||||
func Create(wgIface iFaceMapper, mtu uint16) (*Manager, error) {
|
||||
m := &Manager{
|
||||
rConn: &nftables.Conn{},
|
||||
wgIface: wgIface,
|
||||
}
|
||||
|
||||
workTable := &nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4}
|
||||
workTable := &nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4}
|
||||
|
||||
var err error
|
||||
m.router, err = newRouter(workTable, wgIface)
|
||||
m.router, err = newRouter(workTable, wgIface, mtu)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create router: %w", err)
|
||||
}
|
||||
@@ -93,6 +102,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
NameStr: m.wgIface.Name(),
|
||||
WGAddress: m.wgIface.Address(),
|
||||
UserspaceBind: m.wgIface.IsUserspaceBind(),
|
||||
MTU: m.router.mtu,
|
||||
},
|
||||
}); err != nil {
|
||||
log.Errorf("failed to update state: %v", err)
|
||||
@@ -197,44 +207,11 @@ func (m *Manager) AllowNetbird() error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
err := m.aclManager.createDefaultAllowRules()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create default allow rules: %v", err)
|
||||
if err := m.aclManager.createDefaultAllowRules(); err != nil {
|
||||
return fmt.Errorf("create default allow rules: %w", err)
|
||||
}
|
||||
|
||||
chains, err := m.rConn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list of chains: %w", err)
|
||||
}
|
||||
|
||||
var chain *nftables.Chain
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
chain = c
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if chain == nil {
|
||||
log.Debugf("chain INPUT not found. Skipping add allow netbird rule")
|
||||
return nil
|
||||
}
|
||||
|
||||
rules, err := m.rConn.GetRules(chain.Table, chain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rules for the INPUT chain: %v", err)
|
||||
}
|
||||
|
||||
if rule := m.detectAllowNetbirdRule(rules); rule != nil {
|
||||
log.Debugf("allow netbird rule already exists: %v", rule)
|
||||
return nil
|
||||
}
|
||||
|
||||
m.applyAllowNetbirdRules(chain)
|
||||
|
||||
err = m.rConn.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to flush allow input netbird rules: %v", err)
|
||||
if err := m.rConn.Flush(); err != nil {
|
||||
return fmt.Errorf("flush allow input netbird rules: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -250,10 +227,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if err := m.resetNetbirdInputRules(); err != nil {
|
||||
return fmt.Errorf("reset netbird input rules: %v", err)
|
||||
}
|
||||
|
||||
if err := m.router.Reset(); err != nil {
|
||||
return fmt.Errorf("reset router: %v", err)
|
||||
}
|
||||
@@ -273,49 +246,15 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) resetNetbirdInputRules() error {
|
||||
chains, err := m.rConn.ListChains()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %w", err)
|
||||
}
|
||||
|
||||
m.deleteNetbirdInputRules(chains)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) deleteNetbirdInputRules(chains []*nftables.Chain) {
|
||||
for _, c := range chains {
|
||||
if c.Table.Name == tableNameFilter && c.Name == chainNameInput {
|
||||
rules, err := m.rConn.GetRules(c.Table, c)
|
||||
if err != nil {
|
||||
log.Errorf("get rules for chain %q: %v", c.Name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
m.deleteMatchingRules(rules)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) deleteMatchingRules(rules []*nftables.Rule) {
|
||||
for _, r := range rules {
|
||||
if bytes.Equal(r.UserData, []byte(allowNetbirdInputRuleID)) {
|
||||
if err := m.rConn.DelRule(r); err != nil {
|
||||
log.Errorf("delete rule: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNetbirdTables() error {
|
||||
tables, err := m.rConn.ListTables()
|
||||
if err != nil {
|
||||
return fmt.Errorf("list tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
@@ -398,55 +337,18 @@ func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||
return nil, fmt.Errorf("list of tables: %w", err)
|
||||
}
|
||||
|
||||
tableName := getTableName()
|
||||
for _, t := range tables {
|
||||
if t.Name == tableNameNetbird {
|
||||
if t.Name == tableName {
|
||||
m.rConn.DelTable(t)
|
||||
}
|
||||
}
|
||||
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: tableNameNetbird, Family: nftables.TableFamilyIPv4})
|
||||
table := m.rConn.AddTable(&nftables.Table{Name: getTableName(), Family: nftables.TableFamilyIPv4})
|
||||
err = m.rConn.Flush()
|
||||
return table, err
|
||||
}
|
||||
|
||||
func (m *Manager) applyAllowNetbirdRules(chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: chain.Table,
|
||||
Chain: chain,
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(m.wgIface.Name()),
|
||||
},
|
||||
&expr.Verdict{
|
||||
Kind: expr.VerdictAccept,
|
||||
},
|
||||
},
|
||||
UserData: []byte(allowNetbirdInputRuleID),
|
||||
}
|
||||
_ = m.rConn.InsertRule(rule)
|
||||
}
|
||||
|
||||
func (m *Manager) detectAllowNetbirdRule(existedRules []*nftables.Rule) *nftables.Rule {
|
||||
ifName := ifname(m.wgIface.Name())
|
||||
for _, rule := range existedRules {
|
||||
if rule.Table.Name == tableNameFilter && rule.Chain.Name == chainNameInput {
|
||||
if len(rule.Exprs) < 4 {
|
||||
if e, ok := rule.Exprs[0].(*expr.Meta); !ok || e.Key != expr.MetaKeyIIFNAME {
|
||||
continue
|
||||
}
|
||||
if e, ok := rule.Exprs[1].(*expr.Cmp); !ok || e.Op != expr.CmpOpEq || !bytes.Equal(e.Data, ifName) {
|
||||
continue
|
||||
}
|
||||
return rule
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func insertReturnTrafficRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain) {
|
||||
rule := &nftables.Rule{
|
||||
Table: table,
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -56,7 +57,7 @@ func (i *iFaceMock) IsUserspaceBind() bool { return false }
|
||||
func TestNftablesManager(t *testing.T) {
|
||||
|
||||
// just check on the local interface
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -168,7 +169,7 @@ func TestNftablesManager(t *testing.T) {
|
||||
func TestNftablesManagerRuleOrder(t *testing.T) {
|
||||
// This test verifies rule insertion order in nftables peer ACLs
|
||||
// We add accept rule first, then deny rule to test ordering behavior
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
@@ -261,7 +262,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||
// just check on the local interface
|
||||
manager, err := Create(mock)
|
||||
manager, err := Create(mock, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, manager.Init(nil))
|
||||
time.Sleep(time.Second * 3)
|
||||
@@ -345,7 +346,7 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||
stdout, stderr := runIptablesSave(t)
|
||||
verifyIptablesOutput(t, stdout, stderr)
|
||||
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "failed to create manager")
|
||||
require.NoError(t, manager.Init(nil))
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/nftables/xt"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -32,12 +33,17 @@ const (
|
||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||
chainNameForward = "FORWARD"
|
||||
chainNameMangleForward = "netbird-mangle-forward"
|
||||
|
||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||
userDataAcceptInputRule = "inputaccept"
|
||||
|
||||
dnatSuffix = "_dnat"
|
||||
snatSuffix = "_snat"
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const refreshRulesMapError = "refresh rules map: %w"
|
||||
@@ -63,9 +69,10 @@ type router struct {
|
||||
wgIface iFaceMapper
|
||||
ipFwdState *ipfwdstate.IPForwardingState
|
||||
legacyManagement bool
|
||||
mtu uint16
|
||||
}
|
||||
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
|
||||
r := &router{
|
||||
conn: &nftables.Conn{},
|
||||
workTable: workTable,
|
||||
@@ -73,6 +80,7 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
rules: make(map[string]*nftables.Rule),
|
||||
wgIface: wgIface,
|
||||
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||
mtu: mtu,
|
||||
}
|
||||
|
||||
r.ipsetCounter = refcounter.New(
|
||||
@@ -96,8 +104,8 @@ func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error)
|
||||
func (r *router) init(workTable *nftables.Table) error {
|
||||
r.workTable = workTable
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from FORWARD chain: %s", err)
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
log.Errorf("failed to clean up rules from filter table: %s", err)
|
||||
}
|
||||
|
||||
if err := r.createContainers(); err != nil {
|
||||
@@ -111,15 +119,15 @@ func (r *router) init(workTable *nftables.Table) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset cleans existing nftables default forward rules from the system
|
||||
// Reset cleans existing nftables filter table rules from the system
|
||||
func (r *router) Reset() error {
|
||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||
r.ipsetCounter.Clear()
|
||||
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.removeAcceptForwardRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||
if err := r.removeAcceptFilterRules(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
|
||||
}
|
||||
|
||||
if err := r.removeNatPreroutingRules(); err != nil {
|
||||
@@ -220,11 +228,23 @@ func (r *router) createContainers() error {
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
r.chains[chainNameMangleForward] = r.conn.AddChain(&nftables.Chain{
|
||||
Name: chainNameMangleForward,
|
||||
Table: r.workTable,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
Priority: nftables.ChainPriorityMangle,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
})
|
||||
|
||||
// Add the single NAT rule that matches on mark
|
||||
if err := r.addPostroutingRules(); err != nil {
|
||||
return fmt.Errorf("add single nat rule: %v", err)
|
||||
}
|
||||
|
||||
if err := r.addMSSClampingRules(); err != nil {
|
||||
log.Errorf("failed to add MSS clamping rules: %s", err)
|
||||
}
|
||||
|
||||
if err := r.acceptForwardRules(); err != nil {
|
||||
log.Errorf("failed to add accept rules for the forward chain: %s", err)
|
||||
}
|
||||
@@ -745,6 +765,83 @@ func (r *router) addPostroutingRules() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// addMSSClampingRules adds MSS clamping rules to prevent fragmentation for forwarded traffic.
|
||||
// TODO: Add IPv6 support
|
||||
func (r *router) addMSSClampingRules() error {
|
||||
mss := r.mtu - ipTCPHeaderMinSize
|
||||
|
||||
exprsOut := []expr.Any{
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyOIFNAME,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: ifname(r.wgIface.Name()),
|
||||
},
|
||||
&expr.Meta{
|
||||
Key: expr.MetaKeyL4PROTO,
|
||||
Register: 1,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: []byte{unix.IPPROTO_TCP},
|
||||
},
|
||||
&expr.Payload{
|
||||
DestRegister: 1,
|
||||
Base: expr.PayloadBaseTransportHeader,
|
||||
Offset: 13,
|
||||
Len: 1,
|
||||
},
|
||||
&expr.Bitwise{
|
||||
DestRegister: 1,
|
||||
SourceRegister: 1,
|
||||
Len: 1,
|
||||
Mask: []byte{0x02},
|
||||
Xor: []byte{0x00},
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpNeq,
|
||||
Register: 1,
|
||||
Data: []byte{0x00},
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Exthdr{
|
||||
DestRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpGt,
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Immediate{
|
||||
Register: 1,
|
||||
Data: binaryutil.BigEndian.PutUint16(uint16(mss)),
|
||||
},
|
||||
&expr.Exthdr{
|
||||
SourceRegister: 1,
|
||||
Type: 2,
|
||||
Offset: 2,
|
||||
Len: 2,
|
||||
Op: expr.ExthdrOpTcpopt,
|
||||
},
|
||||
}
|
||||
|
||||
r.conn.AddRule(&nftables.Rule{
|
||||
Table: r.workTable,
|
||||
Chain: r.chains[chainNameMangleForward],
|
||||
Exprs: exprsOut,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||
@@ -840,6 +937,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
||||
// that our traffic is not dropped by existing rules there.
|
||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||
// This method also adds INPUT chain rules to allow traffic to the local interface.
|
||||
func (r *router) acceptForwardRules() error {
|
||||
if r.filterTable == nil {
|
||||
log.Debugf("table 'filter' not found for forward rules, skipping accept rules")
|
||||
@@ -849,7 +947,7 @@ func (r *router) acceptForwardRules() error {
|
||||
fw := "iptables"
|
||||
|
||||
defer func() {
|
||||
log.Debugf("Used %s to add accept forward rules", fw)
|
||||
log.Debugf("Used %s to add accept forward and input rules", fw)
|
||||
}()
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
@@ -859,22 +957,30 @@ func (r *router) acceptForwardRules() error {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
|
||||
fw = "nftables"
|
||||
return r.acceptForwardRulesNftables()
|
||||
return r.acceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.acceptForwardRulesIptables(ipt)
|
||||
return r.acceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) acceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.Insert("filter", chainNameForward, 1, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables forward rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables rule: %v", rule)
|
||||
log.Debugf("added iptables forward rule: %v", rule)
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.Insert("filter", chainNameInput, 1, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("add iptables input rule: %v", err))
|
||||
} else {
|
||||
log.Debugf("added iptables input rule: %v", inputRule)
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
@@ -886,10 +992,13 @@ func (r *router) getAcceptForwardRules() [][]string {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *router) acceptForwardRulesNftables() error {
|
||||
func (r *router) getAcceptInputRule() []string {
|
||||
return []string{"-i", r.wgIface.Name(), "-j", "ACCEPT"}
|
||||
}
|
||||
|
||||
func (r *router) acceptFilterRulesNftables() error {
|
||||
intf := ifname(r.wgIface.Name())
|
||||
|
||||
// Rule for incoming interface (iif) with counter
|
||||
iifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
@@ -922,11 +1031,10 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
},
|
||||
}
|
||||
|
||||
// Rule for outgoing interface (oif) with counter
|
||||
oifRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: "FORWARD",
|
||||
Name: chainNameForward,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookForward,
|
||||
@@ -935,35 +1043,60 @@ func (r *router) acceptForwardRulesNftables() error {
|
||||
Exprs: append(oifExprs, getEstablishedExprs(2)...),
|
||||
UserData: []byte(userDataAcceptForwardRuleOif),
|
||||
}
|
||||
|
||||
r.conn.InsertRule(oifRule)
|
||||
|
||||
inputRule := &nftables.Rule{
|
||||
Table: r.filterTable,
|
||||
Chain: &nftables.Chain{
|
||||
Name: chainNameInput,
|
||||
Table: r.filterTable,
|
||||
Type: nftables.ChainTypeFilter,
|
||||
Hooknum: nftables.ChainHookInput,
|
||||
Priority: nftables.ChainPriorityFilter,
|
||||
},
|
||||
Exprs: []expr.Any{
|
||||
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||
&expr.Cmp{
|
||||
Op: expr.CmpOpEq,
|
||||
Register: 1,
|
||||
Data: intf,
|
||||
},
|
||||
&expr.Counter{},
|
||||
&expr.Verdict{Kind: expr.VerdictAccept},
|
||||
},
|
||||
UserData: []byte(userDataAcceptInputRule),
|
||||
}
|
||||
r.conn.InsertRule(inputRule)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRules() error {
|
||||
func (r *router) removeAcceptFilterRules() error {
|
||||
if r.filterTable == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try iptables first and fallback to nftables if iptables is not available
|
||||
ipt, err := iptables.New()
|
||||
if err != nil {
|
||||
log.Warnf("Will use nftables to manipulate the filter table because iptables is not available: %v", err)
|
||||
return r.removeAcceptForwardRulesNftables()
|
||||
return r.removeAcceptFilterRulesNftables()
|
||||
}
|
||||
|
||||
return r.removeAcceptForwardRulesIptables(ipt)
|
||||
return r.removeAcceptFilterRulesIptables(ipt)
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
func (r *router) removeAcceptFilterRulesNftables() error {
|
||||
chains, err := r.conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list chains: %v", err)
|
||||
}
|
||||
|
||||
for _, chain := range chains {
|
||||
if chain.Table.Name != r.filterTable.Name || chain.Name != chainNameForward {
|
||||
if chain.Table.Name != r.filterTable.Name {
|
||||
continue
|
||||
}
|
||||
|
||||
if chain.Name != chainNameForward && chain.Name != chainNameInput {
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -974,7 +1107,8 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
|
||||
for _, rule := range rules {
|
||||
if bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleIif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) {
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptForwardRuleOif)) ||
|
||||
bytes.Equal(rule.UserData, []byte(userDataAcceptInputRule)) {
|
||||
if err := r.conn.DelRule(rule); err != nil {
|
||||
return fmt.Errorf("delete rule: %v", err)
|
||||
}
|
||||
@@ -989,14 +1123,20 @@ func (r *router) removeAcceptForwardRulesNftables() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error {
|
||||
func (r *router) removeAcceptFilterRulesIptables(ipt *iptables.IPTables) error {
|
||||
var merr *multierror.Error
|
||||
|
||||
for _, rule := range r.getAcceptForwardRules() {
|
||||
if err := ipt.DeleteIfExists("filter", chainNameForward, rule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables rule: %v", err))
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables forward rule: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
inputRule := r.getAcceptInputRule()
|
||||
if err := ipt.DeleteIfExists("filter", chainNameInput, inputRule...); err != nil {
|
||||
merr = multierror.Append(err, fmt.Errorf("remove iptables input rule: %v", err))
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/test"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -36,7 +37,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
||||
for _, testCase := range test.InsertRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
// need fw manager to init both acl mgr and router for all chains to be present
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -125,7 +126,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
||||
|
||||
for _, testCase := range test.RemoveRuleTestCases {
|
||||
t.Run(testCase.Name, func(t *testing.T) {
|
||||
manager, err := Create(ifaceMock)
|
||||
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
})
|
||||
@@ -197,7 +198,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
@@ -364,7 +365,7 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
||||
|
||||
defer deleteWorkTable()
|
||||
|
||||
r, err := newRouter(workTable, ifaceMock)
|
||||
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||
require.NoError(t, err, "Failed to create router")
|
||||
require.NoError(t, r.init(workTable))
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package nftables
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
|
||||
@@ -10,6 +11,7 @@ type InterfaceState struct {
|
||||
NameStr string `json:"name"`
|
||||
WGAddress wgaddr.Address `json:"wg_address"`
|
||||
UserspaceBind bool `json:"userspace_bind"`
|
||||
MTU uint16 `json:"mtu"`
|
||||
}
|
||||
|
||||
func (i *InterfaceState) Name() string {
|
||||
@@ -33,7 +35,11 @@ func (s *ShutdownState) Name() string {
|
||||
}
|
||||
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
nft, err := Create(s.InterfaceState)
|
||||
mtu := s.InterfaceState.MTU
|
||||
if mtu == 0 {
|
||||
mtu = iface.DefaultMTU
|
||||
}
|
||||
nft, err := Create(s.InterfaceState, mtu)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create nftables manager: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
@@ -27,7 +28,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
const layerTypeAll = 0
|
||||
const (
|
||||
layerTypeAll = 0
|
||||
|
||||
// ipTCPHeaderMinSize represents minimum IP (20) + TCP (20) header size for MSS calculation
|
||||
ipTCPHeaderMinSize = 40
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed.
|
||||
@@ -36,6 +42,9 @@ const (
|
||||
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
|
||||
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"
|
||||
|
||||
// EnvDisableMSSClamping disables TCP MSS clamping for forwarded traffic.
|
||||
EnvDisableMSSClamping = "NB_DISABLE_MSS_CLAMPING"
|
||||
|
||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||
|
||||
@@ -50,6 +59,12 @@ const (
|
||||
|
||||
var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
||||
|
||||
// serviceKey represents a protocol/port combination for netstack service registry
|
||||
type serviceKey struct {
|
||||
protocol gopacket.LayerType
|
||||
port uint16
|
||||
}
|
||||
|
||||
// RuleSet is a set of rules grouped by a string key
|
||||
type RuleSet map[string]PeerRule
|
||||
|
||||
@@ -113,6 +128,13 @@ type Manager struct {
|
||||
portDNATEnabled atomic.Bool
|
||||
portDNATRules []portDNATRule
|
||||
portDNATMutex sync.RWMutex
|
||||
|
||||
netstackServices map[serviceKey]struct{}
|
||||
netstackServiceMutex sync.RWMutex
|
||||
|
||||
mtu uint16
|
||||
mssClampValue uint16
|
||||
mssClampEnabled bool
|
||||
}
|
||||
|
||||
// decoder for packages
|
||||
@@ -131,16 +153,16 @@ type decoder struct {
|
||||
}
|
||||
|
||||
// Create userspace firewall manager constructor
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
return create(iface, nil, disableServerRoutes, flowLogger, mtu)
|
||||
}
|
||||
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
if nativeFirewall == nil {
|
||||
return nil, errors.New("native firewall is nil")
|
||||
}
|
||||
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger, mtu)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -148,8 +170,8 @@ func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
func parseCreateEnv() (bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding bool
|
||||
func parseCreateEnv() (bool, bool, bool) {
|
||||
var disableConntrack, enableLocalForwarding, disableMSSClamping bool
|
||||
var err error
|
||||
if val := os.Getenv(EnvDisableConntrack); val != "" {
|
||||
disableConntrack, err = strconv.ParseBool(val)
|
||||
@@ -168,12 +190,18 @@ func parseCreateEnv() (bool, bool) {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||
}
|
||||
}
|
||||
if val := os.Getenv(EnvDisableMSSClamping); val != "" {
|
||||
disableMSSClamping, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvDisableMSSClamping, err)
|
||||
}
|
||||
}
|
||||
|
||||
return disableConntrack, enableLocalForwarding
|
||||
return disableConntrack, enableLocalForwarding, disableMSSClamping
|
||||
}
|
||||
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger, mtu uint16) (*Manager, error) {
|
||||
disableConntrack, enableLocalForwarding, disableMSSClamping := parseCreateEnv()
|
||||
|
||||
m := &Manager{
|
||||
decoders: sync.Pool{
|
||||
@@ -203,13 +231,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
localForwarding: enableLocalForwarding,
|
||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||
portDNATRules: []portDNATRule{},
|
||||
netstackServices: make(map[serviceKey]struct{}),
|
||||
mtu: mtu,
|
||||
}
|
||||
m.routingEnabled.Store(false)
|
||||
|
||||
if !disableMSSClamping {
|
||||
m.mssClampEnabled = true
|
||||
m.mssClampValue = mtu - ipTCPHeaderMinSize
|
||||
}
|
||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||
}
|
||||
|
||||
if disableConntrack {
|
||||
log.Info("conntrack is disabled")
|
||||
} else {
|
||||
@@ -217,14 +250,11 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||
}
|
||||
|
||||
// netstack needs the forwarder for local traffic
|
||||
if m.netstack && m.localForwarding {
|
||||
if err := m.initForwarder(); err != nil {
|
||||
log.Errorf("failed to initialize forwarder: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := iface.SetFilter(m); err != nil {
|
||||
return nil, fmt.Errorf("set filter: %w", err)
|
||||
}
|
||||
@@ -327,7 +357,7 @@ func (m *Manager) initForwarder() error {
|
||||
return errors.New("forwarding not supported")
|
||||
}
|
||||
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack, m.mtu)
|
||||
if err != nil {
|
||||
m.routingEnabled.Store(false)
|
||||
return fmt.Errorf("create forwarder: %w", err)
|
||||
@@ -633,8 +663,17 @@ func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeUDP:
|
||||
if m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||
return true
|
||||
}
|
||||
case layers.LayerTypeTCP:
|
||||
// Clamp MSS on all TCP SYN packets, including those from local IPs.
|
||||
// SNATed routed traffic may appear as local IP but still requires clamping.
|
||||
if m.mssClampEnabled {
|
||||
m.clampTCPMSS(packetData, d)
|
||||
}
|
||||
}
|
||||
|
||||
m.trackOutbound(d, srcIP, dstIP, packetData, size)
|
||||
@@ -681,6 +720,97 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||
return flags
|
||||
}
|
||||
|
||||
// clampTCPMSS clamps the TCP MSS option in SYN and SYN-ACK packets to prevent fragmentation.
|
||||
// Both sides advertise their MSS during connection establishment, so we need to clamp both.
|
||||
func (m *Manager) clampTCPMSS(packetData []byte, d *decoder) bool {
|
||||
if !d.tcp.SYN {
|
||||
return false
|
||||
}
|
||||
if len(d.tcp.Options) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
mssOptionIndex := -1
|
||||
var currentMSS uint16
|
||||
for i, opt := range d.tcp.Options {
|
||||
if opt.OptionType == layers.TCPOptionKindMSS && len(opt.OptionData) == 2 {
|
||||
currentMSS = binary.BigEndian.Uint16(opt.OptionData)
|
||||
if currentMSS > m.mssClampValue {
|
||||
mssOptionIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if mssOptionIndex == -1 {
|
||||
return false
|
||||
}
|
||||
|
||||
ipHeaderSize := int(d.ip4.IHL) * 4
|
||||
if ipHeaderSize < 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
if !m.updateMSSOption(packetData, d, mssOptionIndex, ipHeaderSize) {
|
||||
return false
|
||||
}
|
||||
|
||||
m.logger.Trace2("Clamped TCP MSS from %d to %d", currentMSS, m.mssClampValue)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) updateMSSOption(packetData []byte, d *decoder, mssOptionIndex, ipHeaderSize int) bool {
|
||||
tcpHeaderStart := ipHeaderSize
|
||||
tcpOptionsStart := tcpHeaderStart + 20
|
||||
|
||||
optOffset := tcpOptionsStart
|
||||
for j := 0; j < mssOptionIndex; j++ {
|
||||
switch d.tcp.Options[j].OptionType {
|
||||
case layers.TCPOptionKindEndList, layers.TCPOptionKindNop:
|
||||
optOffset++
|
||||
default:
|
||||
optOffset += 2 + len(d.tcp.Options[j].OptionData)
|
||||
}
|
||||
}
|
||||
|
||||
mssValueOffset := optOffset + 2
|
||||
binary.BigEndian.PutUint16(packetData[mssValueOffset:mssValueOffset+2], m.mssClampValue)
|
||||
|
||||
m.recalculateTCPChecksum(packetData, d, tcpHeaderStart)
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) recalculateTCPChecksum(packetData []byte, d *decoder, tcpHeaderStart int) {
|
||||
tcpLayer := packetData[tcpHeaderStart:]
|
||||
tcpLength := len(packetData) - tcpHeaderStart
|
||||
|
||||
tcpLayer[16] = 0
|
||||
tcpLayer[17] = 0
|
||||
|
||||
var pseudoSum uint32
|
||||
pseudoSum += uint32(d.ip4.SrcIP[0])<<8 | uint32(d.ip4.SrcIP[1])
|
||||
pseudoSum += uint32(d.ip4.SrcIP[2])<<8 | uint32(d.ip4.SrcIP[3])
|
||||
pseudoSum += uint32(d.ip4.DstIP[0])<<8 | uint32(d.ip4.DstIP[1])
|
||||
pseudoSum += uint32(d.ip4.DstIP[2])<<8 | uint32(d.ip4.DstIP[3])
|
||||
pseudoSum += uint32(d.ip4.Protocol)
|
||||
pseudoSum += uint32(tcpLength)
|
||||
|
||||
var sum uint32 = pseudoSum
|
||||
for i := 0; i < tcpLength-1; i += 2 {
|
||||
sum += uint32(tcpLayer[i])<<8 | uint32(tcpLayer[i+1])
|
||||
}
|
||||
if tcpLength%2 == 1 {
|
||||
sum += uint32(tcpLayer[tcpLength-1]) << 8
|
||||
}
|
||||
|
||||
for sum > 0xFFFF {
|
||||
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||
}
|
||||
|
||||
checksum := ^uint16(sum)
|
||||
binary.BigEndian.PutUint16(tcpLayer[16:18], checksum)
|
||||
}
|
||||
|
||||
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) {
|
||||
transport := d.decoded[1]
|
||||
switch transport {
|
||||
@@ -838,9 +968,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
||||
return true
|
||||
}
|
||||
|
||||
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||
if m.shouldForward(d, dstIP) {
|
||||
return m.handleForwardedLocalTraffic(packetData)
|
||||
}
|
||||
|
||||
@@ -1274,3 +1402,86 @@ func (m *Manager) DisableRouting() error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RegisterNetstackService registers a service as listening on the netstack for the given protocol and port
|
||||
func (m *Manager) RegisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
m.netstackServices[key] = struct{}{}
|
||||
m.logger.Debug3("RegisterNetstackService: registered %s:%d (layerType=%s)", protocol, port, layerType)
|
||||
m.logger.Debug1("RegisterNetstackService: current registry size: %d", len(m.netstackServices))
|
||||
}
|
||||
|
||||
// UnregisterNetstackService removes a service from the netstack registry
|
||||
func (m *Manager) UnregisterNetstackService(protocol nftypes.Protocol, port uint16) {
|
||||
m.netstackServiceMutex.Lock()
|
||||
defer m.netstackServiceMutex.Unlock()
|
||||
layerType := m.protocolToLayerType(protocol)
|
||||
key := serviceKey{protocol: layerType, port: port}
|
||||
delete(m.netstackServices, key)
|
||||
m.logger.Debug2("Unregistered netstack service on protocol %s port %d", protocol, port)
|
||||
}
|
||||
|
||||
// protocolToLayerType converts nftypes.Protocol to gopacket.LayerType for internal use
|
||||
func (m *Manager) protocolToLayerType(protocol nftypes.Protocol) gopacket.LayerType {
|
||||
switch protocol {
|
||||
case nftypes.TCP:
|
||||
return layers.LayerTypeTCP
|
||||
case nftypes.UDP:
|
||||
return layers.LayerTypeUDP
|
||||
case nftypes.ICMP:
|
||||
return layers.LayerTypeICMPv4
|
||||
default:
|
||||
return gopacket.LayerType(0) // Invalid/unknown
|
||||
}
|
||||
}
|
||||
|
||||
// shouldForward determines if a packet should be forwarded to the forwarder.
|
||||
// The forwarder handles routing packets to the native OS network stack.
|
||||
// Returns true if packet should go to the forwarder, false if it should go to netstack listeners or the native stack directly.
|
||||
func (m *Manager) shouldForward(d *decoder, dstIP netip.Addr) bool {
|
||||
// not enabled, never forward
|
||||
if !m.localForwarding {
|
||||
return false
|
||||
}
|
||||
|
||||
// netstack always needs to forward because it's lacking a native interface
|
||||
// exception for registered netstack services, those should go to netstack listeners
|
||||
if m.netstack {
|
||||
return !m.hasMatchingNetstackService(d)
|
||||
}
|
||||
|
||||
// traffic to our other local interfaces (not NetBird IP) - always forward
|
||||
if dstIP != m.wgIface.Address().IP {
|
||||
return true
|
||||
}
|
||||
|
||||
// traffic to our NetBird IP, not netstack mode - send to netstack listeners
|
||||
return false
|
||||
}
|
||||
|
||||
// hasMatchingNetstackService checks if there's a registered netstack service for this packet
|
||||
func (m *Manager) hasMatchingNetstackService(d *decoder) bool {
|
||||
if len(d.decoded) < 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
var dstPort uint16
|
||||
switch d.decoded[1] {
|
||||
case layers.LayerTypeTCP:
|
||||
dstPort = uint16(d.tcp.DstPort)
|
||||
case layers.LayerTypeUDP:
|
||||
dstPort = uint16(d.udp.DstPort)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
key := serviceKey{protocol: d.decoded[1], port: dstPort}
|
||||
m.netstackServiceMutex.RLock()
|
||||
_, exists := m.netstackServices[key]
|
||||
m.netstackServiceMutex.RUnlock()
|
||||
|
||||
return exists
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -169,7 +170,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
||||
// Create manager and basic setup
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -209,7 +210,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -252,7 +253,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -410,7 +411,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -537,7 +538,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -620,7 +621,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -731,7 +732,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -811,7 +812,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
|
||||
manager, _ := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
defer b.Cleanup(func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
})
|
||||
@@ -896,38 +897,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
// generateTCPPacketWithFlags creates a TCP packet with specific flags
|
||||
func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort, flags uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ipv4 := &layers.IPv4{
|
||||
TTL: 64,
|
||||
Version: 4,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
}
|
||||
|
||||
// Set TCP flags
|
||||
tcp.SYN = (flags & uint16(conntrack.TCPSyn)) != 0
|
||||
tcp.ACK = (flags & uint16(conntrack.TCPAck)) != 0
|
||||
tcp.PSH = (flags & uint16(conntrack.TCPPush)) != 0
|
||||
tcp.RST = (flags & uint16(conntrack.TCPRst)) != 0
|
||||
tcp.FIN = (flags & uint16(conntrack.TCPFin)) != 0
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test")))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func BenchmarkRouteACLs(b *testing.B) {
|
||||
manager := setupRoutedManager(b, "10.10.0.100/16")
|
||||
|
||||
@@ -990,3 +959,231 @@ func BenchmarkRouteACLs(b *testing.B) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClamping benchmarks the MSS clamping impact on filterOutbound.
|
||||
// This shows the overhead difference between the common case (non-SYN packets, fast path)
|
||||
// and the rare case (SYN packets that need clamping, expensive path).
|
||||
func BenchmarkMSSClamping(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
description string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
frequency string
|
||||
}{
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
description: "SYN packet needing MSS clamping",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
frequency: "~0.1% of traffic - EXPENSIVE",
|
||||
},
|
||||
{
|
||||
name: "syn_no_clamp_needed",
|
||||
description: "SYN packet with already-small MSS",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1200)
|
||||
},
|
||||
frequency: "~0.05% of traffic",
|
||||
},
|
||||
{
|
||||
name: "tcp_ack",
|
||||
description: "Non-SYN TCP packet (ACK, data transfer)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~60-70% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "tcp_psh_ack",
|
||||
description: "TCP data packet (PSH+ACK)",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPPush|conntrack.TCPAck))
|
||||
},
|
||||
frequency: "~10-20% of traffic - FAST PATH",
|
||||
},
|
||||
{
|
||||
name: "udp",
|
||||
description: "UDP packet",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
frequency: "~20-30% of traffic - FAST PATH",
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingOverhead compares overhead of MSS clamping enabled vs disabled
|
||||
// for the common case (non-SYN TCP packets).
|
||||
func BenchmarkMSSClampingOverhead(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
enabled bool
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "disabled_tcp_ack",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_tcp_ack",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "disabled_syn_needs_clamp",
|
||||
enabled: false,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enabled_syn_needs_clamp",
|
||||
enabled: true,
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = sc.enabled
|
||||
if sc.enabled {
|
||||
manager.mssClampValue = 1240
|
||||
}
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkMSSClampingMemory measures memory allocations for common vs rare cases
|
||||
func BenchmarkMSSClampingMemory(b *testing.B) {
|
||||
scenarios := []struct {
|
||||
name string
|
||||
genPacket func(*testing.B, net.IP, net.IP) []byte
|
||||
}{
|
||||
{
|
||||
name: "tcp_ack_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateTCPPacketWithFlags(b, src, dst, 12345, 80, uint16(conntrack.TCPAck))
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "syn_needs_clamp",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generateSYNPacketWithMSS(b, src, dst, 12345, 80, 1460)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "udp_fast_path",
|
||||
genPacket: func(b *testing.B, src, dst net.IP) []byte {
|
||||
return generatePacket(b, src, dst, 12345, 80, layers.IPProtocolUDP)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, sc := range scenarios {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
}()
|
||||
|
||||
manager.mssClampEnabled = true
|
||||
manager.mssClampValue = 1240
|
||||
|
||||
srcIP := net.ParseIP("100.64.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
packet := sc.genPacket(b, srcIP, dstIP)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func generateSYNPacketNoMSS(b *testing.B, srcIP, dstIP net.IP, srcPort, dstPort uint16) []byte {
|
||||
b.Helper()
|
||||
|
||||
ip := &layers.IPv4{
|
||||
Version: 4,
|
||||
IHL: 5,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcp := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Seq: 1000,
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
require.NoError(b, tcp.SetNetworkLayerForChecksum(ip))
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{
|
||||
FixLengths: true,
|
||||
ComputeChecksums: true,
|
||||
}
|
||||
|
||||
require.NoError(b, gopacket.SerializeLayers(buf, opts, ip, tcp, gopacket.Payload([]byte{})))
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
@@ -31,7 +32,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, manager)
|
||||
|
||||
@@ -616,7 +617,7 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(tb, err)
|
||||
require.NoError(tb, manager.EnableRouting())
|
||||
require.NotNil(tb, manager)
|
||||
@@ -1462,7 +1463,7 @@ func TestRouteACLSet(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package uspfilter
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
@@ -17,6 +18,7 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||
nbiface "github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -66,7 +68,7 @@ func TestManagerCreate(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -86,7 +88,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -119,7 +121,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -215,7 +217,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||
@@ -265,7 +267,7 @@ func TestManagerReset(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -304,7 +306,7 @@ func TestNotMatchByIP(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Errorf("failed to create Manager: %v", err)
|
||||
return
|
||||
@@ -367,7 +369,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
}
|
||||
|
||||
// creating manager instance
|
||||
manager, err := Create(iface, false, flowLogger)
|
||||
manager, err := Create(iface, false, flowLogger, nbiface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create Manager: %s", err)
|
||||
}
|
||||
@@ -413,7 +415,7 @@ func TestRemovePacketHook(t *testing.T) {
|
||||
func TestProcessOutgoingHooks(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close()
|
||||
@@ -495,7 +497,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
time.Sleep(time.Second)
|
||||
|
||||
@@ -522,7 +524,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
manager.udpTracker.Close() // Close the existing tracker
|
||||
@@ -729,7 +731,7 @@ func TestUpdateSetMerge(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -815,7 +817,7 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger)
|
||||
manager, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -923,3 +925,192 @@ func TestUpdateSetDeduplication(t *testing.T) {
|
||||
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSSClamping(t *testing.T) {
|
||||
ifaceMock := &IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
AddressFunc: func() wgaddr.Address {
|
||||
return wgaddr.Address{
|
||||
IP: netip.MustParseAddr("100.10.0.100"),
|
||||
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
manager, err := Create(ifaceMock, false, flowLogger, 1280)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
}()
|
||||
|
||||
require.True(t, manager.mssClampEnabled, "MSS clamping should be enabled by default")
|
||||
expectedMSSValue := uint16(1280 - ipTCPHeaderMinSize)
|
||||
require.Equal(t, expectedMSSValue, manager.mssClampValue, "MSS clamp value should be MTU - 40")
|
||||
|
||||
err = manager.UpdateLocalIPs()
|
||||
require.NoError(t, err)
|
||||
|
||||
srcIP := net.ParseIP("100.10.0.2")
|
||||
dstIP := net.ParseIP("8.8.8.8")
|
||||
|
||||
t.Run("SYN packet with high MSS gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
require.Equal(t, uint8(layers.TCPOptionKindMSS), uint8(d.tcp.Options[0].OptionType))
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS should be clamped to MTU - 40")
|
||||
})
|
||||
|
||||
t.Run("SYN packet with low MSS unchanged", func(t *testing.T) {
|
||||
lowMSS := uint16(1200)
|
||||
packet := generateSYNPacketWithMSS(t, srcIP, dstIP, 12345, 80, lowMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, lowMSS, actualMSS, "Low MSS should not be modified")
|
||||
})
|
||||
|
||||
t.Run("SYN-ACK packet gets clamped", func(t *testing.T) {
|
||||
highMSS := uint16(1460)
|
||||
packet := generateSYNACKPacketWithMSS(t, srcIP, dstIP, 12345, 80, highMSS)
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Len(t, d.tcp.Options, 1, "Should have MSS option")
|
||||
actualMSS := binary.BigEndian.Uint16(d.tcp.Options[0].OptionData)
|
||||
require.Equal(t, expectedMSSValue, actualMSS, "MSS in SYN-ACK should be clamped")
|
||||
})
|
||||
|
||||
t.Run("Non-SYN packet unchanged", func(t *testing.T) {
|
||||
packet := generateTCPPacketWithFlags(t, srcIP, dstIP, 12345, 80, uint16(conntrack.TCPAck))
|
||||
|
||||
manager.filterOutbound(packet, len(packet))
|
||||
|
||||
d := parsePacket(t, packet)
|
||||
require.Empty(t, d.tcp.Options, "ACK packet should have no options")
|
||||
})
|
||||
}
|
||||
|
||||
func generateSYNPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateSYNACKPacketWithMSS(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, mss uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
SYN: true,
|
||||
ACK: true,
|
||||
Window: 65535,
|
||||
Options: []layers.TCPOption{
|
||||
{
|
||||
OptionType: layers.TCPOptionKindMSS,
|
||||
OptionLength: 4,
|
||||
OptionData: binary.BigEndian.AppendUint16(nil, mss),
|
||||
},
|
||||
},
|
||||
}
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func generateTCPPacketWithFlags(tb testing.TB, srcIP, dstIP net.IP, srcPort, dstPort uint16, flags uint16) []byte {
|
||||
tb.Helper()
|
||||
|
||||
ipLayer := &layers.IPv4{
|
||||
Version: 4,
|
||||
TTL: 64,
|
||||
Protocol: layers.IPProtocolTCP,
|
||||
SrcIP: srcIP,
|
||||
DstIP: dstIP,
|
||||
}
|
||||
|
||||
tcpLayer := &layers.TCP{
|
||||
SrcPort: layers.TCPPort(srcPort),
|
||||
DstPort: layers.TCPPort(dstPort),
|
||||
Window: 65535,
|
||||
}
|
||||
|
||||
if flags&uint16(conntrack.TCPSyn) != 0 {
|
||||
tcpLayer.SYN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPAck) != 0 {
|
||||
tcpLayer.ACK = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPFin) != 0 {
|
||||
tcpLayer.FIN = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPRst) != 0 {
|
||||
tcpLayer.RST = true
|
||||
}
|
||||
if flags&uint16(conntrack.TCPPush) != 0 {
|
||||
tcpLayer.PSH = true
|
||||
}
|
||||
|
||||
err := tcpLayer.SetNetworkLayerForChecksum(ipLayer)
|
||||
require.NoError(tb, err)
|
||||
|
||||
buf := gopacket.NewSerializeBuffer()
|
||||
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||
err = gopacket.SerializeLayers(buf, opts, ipLayer, tcpLayer, gopacket.Payload([]byte{}))
|
||||
require.NoError(tb, err)
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ type Forwarder struct {
|
||||
netstack bool
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
s := stack.New(stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{
|
||||
@@ -56,10 +56,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
HandleLocal: false,
|
||||
})
|
||||
|
||||
mtu, err := iface.GetDevice().MTU()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get MTU: %w", err)
|
||||
}
|
||||
nicID := tcpip.NICID(1)
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
@@ -68,7 +64,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
}
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
}
|
||||
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
|
||||
@@ -49,7 +49,7 @@ type idleConn struct {
|
||||
conn *udpPacketConn
|
||||
}
|
||||
|
||||
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
func newUDPForwarder(mtu uint16, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &udpForwarder{
|
||||
logger: logger,
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -65,7 +66,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -125,7 +126,7 @@ func BenchmarkDNATTranslation(b *testing.B) {
|
||||
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -197,7 +198,7 @@ func BenchmarkDNATScaling(b *testing.B) {
|
||||
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -309,7 +310,7 @@ func BenchmarkChecksumUpdate(b *testing.B) {
|
||||
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
@@ -472,7 +473,7 @@ func BenchmarkPortDNAT(b *testing.B) {
|
||||
b.Run(sc.name, func(b *testing.B) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(b, err)
|
||||
defer func() {
|
||||
require.NoError(b, manager.Close(nil))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
)
|
||||
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -100,7 +101,7 @@ func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||
func TestDNATMappingManagement(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -148,7 +149,7 @@ func TestDNATMappingManagement(t *testing.T) {
|
||||
func TestInboundPortDNAT(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
@@ -198,7 +199,7 @@ func TestInboundPortDNAT(t *testing.T) {
|
||||
func TestInboundPortDNATNegative(t *testing.T) {
|
||||
manager, err := Create(&IFaceMock{
|
||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||
}, false, flowLogger)
|
||||
}, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, manager.Close(nil))
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
)
|
||||
@@ -44,7 +45,7 @@ func TestTracePacket(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
m, err := Create(ifaceMock, false, flowLogger)
|
||||
m, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
|
||||
if !statefulMode {
|
||||
|
||||
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/connectivity"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/keepalive"
|
||||
@@ -17,6 +20,9 @@ import (
|
||||
"github.com/netbirdio/netbird/util/embeddedroots"
|
||||
)
|
||||
|
||||
// ErrConnectionShutdown indicates that the connection entered shutdown state before becoming ready
|
||||
var ErrConnectionShutdown = errors.New("connection shutdown before ready")
|
||||
|
||||
// Backoff returns a backoff configuration for gRPC calls
|
||||
func Backoff(ctx context.Context) backoff.BackOff {
|
||||
b := backoff.NewExponentialBackOff()
|
||||
@@ -25,6 +31,26 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
return backoff.WithContext(b, ctx)
|
||||
}
|
||||
|
||||
// waitForConnectionReady blocks until the connection becomes ready or fails.
|
||||
// Returns an error if the connection times out, is cancelled, or enters shutdown state.
|
||||
func waitForConnectionReady(ctx context.Context, conn *grpc.ClientConn) error {
|
||||
conn.Connect()
|
||||
|
||||
state := conn.GetState()
|
||||
for state != connectivity.Ready && state != connectivity.Shutdown {
|
||||
if !conn.WaitForStateChange(ctx, state) {
|
||||
return fmt.Errorf("wait state change from %s: %w", state, ctx.Err())
|
||||
}
|
||||
state = conn.GetState()
|
||||
}
|
||||
|
||||
if state == connectivity.Shutdown {
|
||||
return ErrConnectionShutdown
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
@@ -42,22 +68,24 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
}))
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
grpc.WithKeepaliveParams(keepalive.ClientParameters{
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("DialContext error: %v", err)
|
||||
return nil, fmt.Errorf("new client: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := waitForConnectionReady(ctx, conn); err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
func WithCustomDialer(_ bool, _ string) grpc.DialOption {
|
||||
return grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
|
||||
if runtime.GOOS == "linux" {
|
||||
currentUser, err := user.Current()
|
||||
@@ -36,7 +36,6 @@ func WithCustomDialer(tlsEnabled bool, component string) grpc.DialOption {
|
||||
|
||||
conn, err := nbnet.NewDialer().DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to dial: %s", err)
|
||||
return nil, fmt.Errorf("nbnet.NewDialer().DialContext: %w", err)
|
||||
}
|
||||
return conn, nil
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
@@ -52,7 +53,7 @@ func TestDefaultManager(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -170,7 +171,7 @@ func TestDefaultManagerStateless(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
@@ -321,7 +322,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
}).AnyTimes()
|
||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = fw.Close(nil)
|
||||
|
||||
@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
|
||||
deviceCode.VerificationURIComplete = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
if d.providerConfig.LoginHint != "" {
|
||||
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
|
||||
if deviceCode.VerificationURI != "" {
|
||||
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
|
||||
}
|
||||
}
|
||||
|
||||
return deviceCode, err
|
||||
}
|
||||
|
||||
func appendLoginHint(uri, loginHint string) string {
|
||||
if uri == "" || loginHint == "" {
|
||||
return uri
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse verification URI for login_hint: %v", err)
|
||||
return uri
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
query.Set("login_hint", loginHint)
|
||||
parsedURL.RawQuery = query.Encode()
|
||||
|
||||
return parsedURL.String()
|
||||
}
|
||||
|
||||
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
|
||||
form := url.Values{}
|
||||
form.Add("client_id", d.providerConfig.ClientID)
|
||||
|
||||
@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
|
||||
// and if that also fails, the authentication process is deemed unsuccessful
|
||||
//
|
||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
|
||||
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config)
|
||||
pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
|
||||
if err != nil {
|
||||
// fallback to device code flow
|
||||
log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
|
||||
log.Debug("falling back to device code flow")
|
||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||
return authenticateWithDeviceCodeFlow(ctx, config, hint)
|
||||
}
|
||||
return pkceFlow, nil
|
||||
}
|
||||
|
||||
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
|
||||
}
|
||||
|
||||
pkceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) {
|
||||
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
|
||||
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
|
||||
if err != nil {
|
||||
switch s, ok := gstatus.FromError(err); {
|
||||
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
|
||||
}
|
||||
}
|
||||
|
||||
deviceFlowInfo.ProviderConfig.LoginHint = hint
|
||||
|
||||
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
|
||||
}
|
||||
|
||||
@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
||||
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||
}
|
||||
}
|
||||
if p.providerConfig.LoginHint != "" {
|
||||
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
|
||||
}
|
||||
|
||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -34,7 +35,6 @@ import (
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
@@ -289,15 +289,18 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
if c.engine != nil && c.engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", c.engine.wgInterface.Name())
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
engine := c.engine
|
||||
c.engine = nil
|
||||
c.engineMutex.Unlock()
|
||||
|
||||
if engine != nil && engine.wgInterface != nil {
|
||||
log.Infof("ensuring %s is removed, Netbird engine context cancelled", engine.wgInterface.Name())
|
||||
if err := engine.Stop(); err != nil {
|
||||
log.Errorf("Failed to stop engine: %v", err)
|
||||
}
|
||||
c.engine = nil
|
||||
}
|
||||
c.engineMutex.Unlock()
|
||||
c.statusRecorder.ClientTeardown()
|
||||
|
||||
backOff.Reset()
|
||||
@@ -382,19 +385,12 @@ func (c *ConnectClient) Status() StatusType {
|
||||
}
|
||||
|
||||
func (c *ConnectClient) Stop() error {
|
||||
if c == nil {
|
||||
return nil
|
||||
engine := c.Engine()
|
||||
if engine != nil {
|
||||
if err := engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
}
|
||||
c.engineMutex.Lock()
|
||||
defer c.engineMutex.Unlock()
|
||||
|
||||
if c.engine == nil {
|
||||
return nil
|
||||
}
|
||||
if err := c.engine.Stop(); err != nil {
|
||||
return fmt.Errorf("stop engine: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ interfaces.txt: Anonymized network interface information, if --system-info flag
|
||||
ip_rules.txt: Detailed IP routing rules in tabular format including priority, source, destination, interfaces, table, and action information (Linux only), if --system-info flag was provided.
|
||||
iptables.txt: Anonymized iptables rules with packet counters, if --system-info flag was provided.
|
||||
nftables.txt: Anonymized nftables rules with packet counters, if --system-info flag was provided.
|
||||
resolv.conf: DNS resolver configuration from /etc/resolv.conf (Unix systems only), if --system-info flag was provided.
|
||||
scutil_dns.txt: DNS configuration from scutil --dns (macOS only), if --system-info flag was provided.
|
||||
resolved_domains.txt: Anonymized resolved domain IP addresses from the status recorder.
|
||||
config.txt: Anonymized configuration information of the NetBird client.
|
||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||
@@ -184,6 +186,20 @@ The ip_rules.txt file contains detailed IP routing rule information:
|
||||
The table format provides comprehensive visibility into the IP routing decision process, including how traffic is directed to different routing tables based on various criteria. This is valuable for troubleshooting advanced routing configurations and policy-based routing.
|
||||
|
||||
For anonymized rules, IP addresses and prefixes are replaced as described above. Interface names are anonymized using string anonymization. Table names, actions, and other non-sensitive information remain unchanged.
|
||||
|
||||
DNS Configuration
|
||||
The debug bundle includes platform-specific DNS configuration files:
|
||||
|
||||
resolv.conf (Unix systems):
|
||||
- Contains DNS resolver configuration from /etc/resolv.conf
|
||||
- Includes nameserver entries, search domains, and resolver options
|
||||
- All IP addresses and domain names are anonymized following the same rules as other files
|
||||
|
||||
scutil_dns.txt (macOS only):
|
||||
- Contains detailed DNS configuration from scutil --dns
|
||||
- Shows DNS configuration for all network interfaces
|
||||
- Includes search domains, nameservers, and DNS resolver settings
|
||||
- All IP addresses and domain names are anonymized
|
||||
`
|
||||
|
||||
const (
|
||||
@@ -357,6 +373,10 @@ func (g *BundleGenerator) addSystemInfo() {
|
||||
if err := g.addFirewallRules(); err != nil {
|
||||
log.Errorf("failed to add firewall rules to debug bundle: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addDNSInfo(); err != nil {
|
||||
log.Errorf("failed to add DNS info to debug bundle: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addReadme() error {
|
||||
|
||||
53
client/internal/debug/debug_darwin.go
Normal file
53
client/internal/debug/debug_darwin.go
Normal file
@@ -0,0 +1,53 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addScutilDNS(); err != nil {
|
||||
log.Errorf("failed to add scutil DNS output: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addScutilDNS() error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "scutil", "--dns")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("execute scutil --dns: %w", err)
|
||||
}
|
||||
|
||||
if len(bytes.TrimSpace(output)) == 0 {
|
||||
return fmt.Errorf("no scutil DNS output")
|
||||
}
|
||||
|
||||
content := string(output)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "scutil_dns.txt"); err != nil {
|
||||
return fmt.Errorf("add scutil DNS output to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -5,3 +5,7 @@ package debug
|
||||
func (g *BundleGenerator) addRoutes() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
16
client/internal/debug/debug_nondarwin.go
Normal file
16
client/internal/debug/debug_nondarwin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
//go:build unix && !darwin && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// addDNSInfo collects and adds DNS configuration information to the archive
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
if err := g.addResolvConf(); err != nil {
|
||||
log.Errorf("failed to add resolv.conf: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
client/internal/debug/debug_nonunix.go
Normal file
7
client/internal/debug/debug_nonunix.go
Normal file
@@ -0,0 +1,7 @@
|
||||
//go:build !unix
|
||||
|
||||
package debug
|
||||
|
||||
func (g *BundleGenerator) addDNSInfo() error {
|
||||
return nil
|
||||
}
|
||||
29
client/internal/debug/debug_unix.go
Normal file
29
client/internal/debug/debug_unix.go
Normal file
@@ -0,0 +1,29 @@
|
||||
//go:build unix && !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const resolvConfPath = "/etc/resolv.conf"
|
||||
|
||||
func (g *BundleGenerator) addResolvConf() error {
|
||||
data, err := os.ReadFile(resolvConfPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read %s: %w", resolvConfPath, err)
|
||||
}
|
||||
|
||||
content := string(data)
|
||||
if g.anonymize {
|
||||
content = g.anonymizer.AnonymizeString(content)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(strings.NewReader(content), "resolv.conf"); err != nil {
|
||||
return fmt.Errorf("add resolv.conf to zip: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
|
||||
Scope string
|
||||
// UseIDToken indicates if the id token should be used for authentication
|
||||
UseIDToken bool
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -65,8 +65,9 @@ type hostManagerWithOriginalNS interface {
|
||||
|
||||
// DefaultServer dns server object
|
||||
type DefaultServer struct {
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
shutdownWg sync.WaitGroup
|
||||
// disableSys disables system DNS management (e.g., /etc/resolv.conf updates) while keeping the DNS service running.
|
||||
// This is different from ServiceEnable=false from management which completely disables the DNS service.
|
||||
disableSys bool
|
||||
@@ -318,6 +319,7 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.ctxCancel()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
defer s.mux.Unlock()
|
||||
@@ -507,8 +509,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
||||
|
||||
s.applyHostConfig()
|
||||
|
||||
s.shutdownWg.Add(1)
|
||||
go func() {
|
||||
// persist dns state right away
|
||||
defer s.shutdownWg.Done()
|
||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||
log.Errorf("Failed to persist dns state: %v", err)
|
||||
}
|
||||
|
||||
@@ -944,7 +944,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||
pf, err := uspfilter.Create(wgIface, false, flowLogger, iface.DefaultMTU)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create uspfilter: %v", err)
|
||||
return nil, err
|
||||
|
||||
@@ -83,4 +83,3 @@ func TestCacheMiss(t *testing.T) {
|
||||
t.Fatalf("expected cache miss, got=%v ok=%v", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -33,7 +34,7 @@ type firewaller interface {
|
||||
}
|
||||
|
||||
type DNSForwarder struct {
|
||||
listenAddress string
|
||||
listenAddress netip.AddrPort
|
||||
ttl uint32
|
||||
statusRecorder *peer.Status
|
||||
|
||||
@@ -47,9 +48,11 @@ type DNSForwarder struct {
|
||||
firewall firewaller
|
||||
resolver resolver
|
||||
cache *cache
|
||||
|
||||
wgIface wgIface
|
||||
}
|
||||
|
||||
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
|
||||
func NewDNSForwarder(listenAddress netip.AddrPort, ttl uint32, firewall firewaller, statusRecorder *peer.Status, wgIface wgIface) *DNSForwarder {
|
||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||
return &DNSForwarder{
|
||||
listenAddress: listenAddress,
|
||||
@@ -58,30 +61,46 @@ func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, stat
|
||||
statusRecorder: statusRecorder,
|
||||
resolver: net.DefaultResolver,
|
||||
cache: newCache(),
|
||||
wgIface: wgIface,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
log.Infof("starting DNS forwarder on address=%s", f.listenAddress)
|
||||
var netstackNet *netstack.Net
|
||||
if f.wgIface != nil {
|
||||
netstackNet = f.wgIface.GetNet()
|
||||
}
|
||||
|
||||
addrDesc := f.listenAddress.String()
|
||||
if netstackNet != nil {
|
||||
addrDesc = fmt.Sprintf("netstack %s", f.listenAddress)
|
||||
}
|
||||
log.Infof("starting DNS forwarder on address=%s", addrDesc)
|
||||
|
||||
udpLn, err := f.createUDPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create UDP listener: %w", err)
|
||||
}
|
||||
|
||||
tcpLn, err := f.createTCPListener(netstackNet)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create TCP listener: %w", err)
|
||||
}
|
||||
|
||||
// UDP server
|
||||
mux := dns.NewServeMux()
|
||||
f.mux = mux
|
||||
mux.HandleFunc(".", f.handleDNSQueryUDP)
|
||||
f.dnsServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "udp",
|
||||
Handler: mux,
|
||||
PacketConn: udpLn,
|
||||
Handler: mux,
|
||||
}
|
||||
|
||||
// TCP server
|
||||
tcpMux := dns.NewServeMux()
|
||||
f.tcpMux = tcpMux
|
||||
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
|
||||
f.tcpServer = &dns.Server{
|
||||
Addr: f.listenAddress,
|
||||
Net: "tcp",
|
||||
Handler: tcpMux,
|
||||
Listener: tcpLn,
|
||||
Handler: tcpMux,
|
||||
}
|
||||
|
||||
f.UpdateDomains(entries)
|
||||
@@ -89,18 +108,33 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
log.Infof("DNS UDP listener running on %s", f.listenAddress)
|
||||
errCh <- f.dnsServer.ListenAndServe()
|
||||
log.Infof("DNS UDP listener running on %s", addrDesc)
|
||||
errCh <- f.dnsServer.ActivateAndServe()
|
||||
}()
|
||||
go func() {
|
||||
log.Infof("DNS TCP listener running on %s", f.listenAddress)
|
||||
errCh <- f.tcpServer.ListenAndServe()
|
||||
log.Infof("DNS TCP listener running on %s", addrDesc)
|
||||
errCh <- f.tcpServer.ActivateAndServe()
|
||||
}()
|
||||
|
||||
// return the first error we get (e.g. bind failure or shutdown)
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createUDPListener(netstackNet *netstack.Net) (net.PacketConn, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenUDPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenUDP("udp", net.UDPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) createTCPListener(netstackNet *netstack.Net) (net.Listener, error) {
|
||||
if netstackNet != nil {
|
||||
return netstackNet.ListenTCPAddrPort(f.listenAddress)
|
||||
}
|
||||
|
||||
return net.ListenTCP("tcp", net.TCPAddrFromAddrPort(f.listenAddress))
|
||||
}
|
||||
|
||||
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
@@ -297,7 +297,7 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
||||
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
|
||||
}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString(tt.configuredDomain)
|
||||
@@ -402,7 +402,7 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
// Set up forwarder
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Create entries and track sets
|
||||
@@ -489,7 +489,7 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Configure a single domain
|
||||
@@ -584,7 +584,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
d, err := domain.FromString(tt.configured)
|
||||
require.NoError(t, err)
|
||||
@@ -616,7 +616,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
||||
func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// Test that large UDP responses are truncated with TC bit set
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, _ := domain.FromString("example.com")
|
||||
@@ -652,7 +652,7 @@ func TestDNSForwarder_TCPTruncation(t *testing.T) {
|
||||
// a subsequent upstream failure still returns a successful response from cache.
|
||||
func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -696,7 +696,7 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
||||
// Verifies that cache normalization works across casing and trailing dot variations.
|
||||
func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
||||
mockResolver := &MockResolver{}
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("ExAmPlE.CoM")
|
||||
@@ -742,7 +742,7 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
// Set up complex overlapping patterns
|
||||
@@ -804,7 +804,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
mockFirewall := &MockFirewall{}
|
||||
mockResolver := &MockResolver{}
|
||||
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, mockFirewall, &peer.Status{}, nil)
|
||||
forwarder.resolver = mockResolver
|
||||
|
||||
d, err := domain.FromString("example.com")
|
||||
@@ -925,7 +925,7 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
||||
|
||||
func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
||||
// Test handling of malformed query with no questions
|
||||
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
|
||||
forwarder := NewDNSForwarder(netip.MustParseAddrPort("127.0.0.1:0"), 300, nil, &peer.Status{}, nil)
|
||||
|
||||
query := &dns.Msg{}
|
||||
// Don't set any question
|
||||
|
||||
@@ -10,9 +10,12 @@ import (
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
@@ -24,6 +27,12 @@ const (
|
||||
envServerPort = "NB_DNS_FORWARDER_PORT"
|
||||
)
|
||||
|
||||
// wgIface defines the interface for WireGuard interface operations needed by the DNS forwarder.
|
||||
type wgIface interface {
|
||||
GetNet() *netstack.Net
|
||||
Address() wgaddr.Address
|
||||
}
|
||||
|
||||
// ForwarderEntry is a mapping from a domain to a resource ID and a hash of the parent domain list.
|
||||
type ForwarderEntry struct {
|
||||
Domain domain.Domain
|
||||
@@ -34,7 +43,7 @@ type ForwarderEntry struct {
|
||||
type Manager struct {
|
||||
firewall firewall.Manager
|
||||
statusRecorder *peer.Status
|
||||
localAddr netip.Addr
|
||||
wgIface wgIface
|
||||
serverPort uint16
|
||||
|
||||
fwRules []firewall.Rule
|
||||
@@ -42,7 +51,7 @@ type Manager struct {
|
||||
dnsForwarder *DNSForwarder
|
||||
}
|
||||
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr netip.Addr) *Manager {
|
||||
func NewManager(fw firewall.Manager, statusRecorder *peer.Status, wgIface wgIface) *Manager {
|
||||
serverPort := nbdns.ForwarderServerPort
|
||||
if envPort := os.Getenv(envServerPort); envPort != "" {
|
||||
if port, err := strconv.ParseUint(envPort, 10, 16); err == nil && port > 0 {
|
||||
@@ -56,7 +65,7 @@ func NewManager(fw firewall.Manager, statusRecorder *peer.Status, localAddr neti
|
||||
return &Manager{
|
||||
firewall: fw,
|
||||
statusRecorder: statusRecorder,
|
||||
localAddr: localAddr,
|
||||
wgIface: wgIface,
|
||||
serverPort: serverPort,
|
||||
}
|
||||
}
|
||||
@@ -71,21 +80,25 @@ func (m *Manager) Start(fwdEntries []*ForwarderEntry) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS UDP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS UDP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
|
||||
if err := m.firewall.AddInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.AddInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
log.Warnf("failed to add DNS TCP DNAT rule: %v", err)
|
||||
} else {
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", m.localAddr, nbdns.ForwarderClientPort, m.localAddr, m.serverPort)
|
||||
log.Infof("added DNS TCP DNAT rule: %s:%d -> %s:%d", localAddr, nbdns.ForwarderClientPort, localAddr, m.serverPort)
|
||||
}
|
||||
}
|
||||
|
||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", m.serverPort), dnsTTL, m.firewall, m.statusRecorder)
|
||||
listenAddress := netip.AddrPortFrom(localAddr, m.serverPort)
|
||||
m.dnsForwarder = NewDNSForwarder(listenAddress, dnsTTL, m.firewall, m.statusRecorder, m.wgIface)
|
||||
|
||||
go func() {
|
||||
if err := m.dnsForwarder.Listen(fwdEntries); err != nil {
|
||||
// todo handle close error if it is exists
|
||||
@@ -111,16 +124,19 @@ func (m *Manager) Stop(ctx context.Context) error {
|
||||
|
||||
var mErr *multierror.Error
|
||||
|
||||
if m.localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
localAddr := m.wgIface.Address().IP
|
||||
if localAddr.IsValid() && m.firewall != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolUDP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS UDP DNAT rule: %w", err))
|
||||
}
|
||||
|
||||
if err := m.firewall.RemoveInboundDNAT(m.localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
if err := m.firewall.RemoveInboundDNAT(localAddr, firewall.ProtocolTCP, nbdns.ForwarderClientPort, m.serverPort); err != nil {
|
||||
mErr = multierror.Append(mErr, fmt.Errorf("remove DNS TCP DNAT rule: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
m.unregisterNetstackServices()
|
||||
|
||||
if err := m.dropDNSFirewall(); err != nil {
|
||||
mErr = multierror.Append(mErr, err)
|
||||
}
|
||||
@@ -145,21 +161,50 @@ func (m *Manager) allowDNSFirewall() error {
|
||||
|
||||
dnsRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add udp firewall rule: %w", err)
|
||||
}
|
||||
m.fwRules = dnsRules
|
||||
|
||||
tcpRules, err := m.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolTCP, nil, dport, firewall.ActionAccept, "")
|
||||
if err != nil {
|
||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("add tcp firewall rule: %w", err)
|
||||
}
|
||||
|
||||
if err := m.firewall.Flush(); err != nil {
|
||||
return fmt.Errorf("flush: %w", err)
|
||||
}
|
||||
|
||||
m.fwRules = dnsRules
|
||||
m.tcpRules = tcpRules
|
||||
|
||||
m.registerNetstackServices()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) registerNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.RegisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.RegisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("registered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) unregisterNetstackServices() {
|
||||
if netstackNet := m.wgIface.GetNet(); netstackNet != nil {
|
||||
if registrar, ok := m.firewall.(interface {
|
||||
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
|
||||
}); ok {
|
||||
registrar.UnregisterNetstackService(nftypes.TCP, m.serverPort)
|
||||
registrar.UnregisterNetstackService(nftypes.UDP, m.serverPort)
|
||||
log.Debugf("unregistered DNS forwarder service with netstack for UDP/TCP:%d", m.serverPort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) dropDNSFirewall() error {
|
||||
var mErr *multierror.Error
|
||||
for _, rule := range m.fwRules {
|
||||
|
||||
@@ -148,6 +148,8 @@ type Engine struct {
|
||||
|
||||
// syncMsgMux is used to guarantee sequential Management Service message processing
|
||||
syncMsgMux *sync.Mutex
|
||||
// sshMux protects sshServer field access
|
||||
sshMux sync.Mutex
|
||||
|
||||
config *EngineConfig
|
||||
mobileDep MobileDependency
|
||||
@@ -200,8 +202,10 @@ type Engine struct {
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
wgIfaceMonitorWg sync.WaitGroup
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
|
||||
// shutdownWg tracks all long-running goroutines to ensure clean shutdown
|
||||
shutdownWg sync.WaitGroup
|
||||
|
||||
probeStunTurn *relay.StunTurnProbe
|
||||
}
|
||||
@@ -298,17 +302,12 @@ func (e *Engine) Stop() error {
|
||||
e.ingressGatewayMgr = nil
|
||||
}
|
||||
|
||||
e.stopDNSForwarder()
|
||||
|
||||
if e.routeManager != nil {
|
||||
e.routeManager.Stop(e.stateManager)
|
||||
}
|
||||
|
||||
if e.dnsForwardMgr != nil {
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
if e.srWatcher != nil {
|
||||
e.srWatcher.Close()
|
||||
}
|
||||
@@ -325,10 +324,6 @@ func (e *Engine) Stop() error {
|
||||
e.cancel()
|
||||
}
|
||||
|
||||
// very ugly but we want to remove peers from the WireGuard interface first before removing interface.
|
||||
// Removing peers happens in the conn.Close() asynchronously
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
e.close()
|
||||
|
||||
// stop flow manager after wg interface is gone
|
||||
@@ -336,8 +331,6 @@ func (e *Engine) Stop() error {
|
||||
e.flowManager.Close()
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -348,12 +341,52 @@ func (e *Engine) Stop() error {
|
||||
log.Errorf("failed to persist state: %v", err)
|
||||
}
|
||||
|
||||
// Stop WireGuard interface monitor and wait for it to exit
|
||||
e.wgIfaceMonitorWg.Wait()
|
||||
timeout := e.calculateShutdownTimeout()
|
||||
log.Debugf("waiting for goroutines to finish with timeout: %v", timeout)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := waitWithContext(shutdownCtx, &e.shutdownWg); err != nil {
|
||||
log.Warnf("shutdown timeout exceeded after %v, some goroutines may still be running", timeout)
|
||||
}
|
||||
|
||||
log.Infof("stopped Netbird Engine")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// calculateShutdownTimeout returns shutdown timeout: 10s base + 100ms per peer, capped at 30s.
|
||||
func (e *Engine) calculateShutdownTimeout() time.Duration {
|
||||
peerCount := len(e.peerStore.PeersPubKey())
|
||||
|
||||
baseTimeout := 10 * time.Second
|
||||
perPeerTimeout := time.Duration(peerCount) * 100 * time.Millisecond
|
||||
timeout := baseTimeout + perPeerTimeout
|
||||
|
||||
maxTimeout := 30 * time.Second
|
||||
if timeout > maxTimeout {
|
||||
timeout = maxTimeout
|
||||
}
|
||||
|
||||
return timeout
|
||||
}
|
||||
|
||||
// waitWithContext waits for WaitGroup with timeout, returns ctx.Err() on timeout.
|
||||
func waitWithContext(ctx context.Context, wg *sync.WaitGroup) error {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
return nil
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// Start creates a new WireGuard tunnel interface and listens to events from Signal and Management services
|
||||
// Connections to remote peers are not established here.
|
||||
// However, they will be established once an event with a list of peers to connect to will be received from Management Service
|
||||
@@ -483,14 +516,14 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
|
||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||
e.wgIfaceMonitorWg.Add(1)
|
||||
e.shutdownWg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer e.wgIfaceMonitorWg.Done()
|
||||
defer e.shutdownWg.Done()
|
||||
|
||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||
e.restartEngine()
|
||||
e.triggerClientRestart()
|
||||
} else if err != nil {
|
||||
log.Warnf("WireGuard interface monitor: %s", err)
|
||||
}
|
||||
@@ -506,7 +539,7 @@ func (e *Engine) createFirewall() error {
|
||||
}
|
||||
|
||||
var err error
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
|
||||
if err != nil || e.firewall == nil {
|
||||
log.Errorf("failed creating firewall manager: %s", err)
|
||||
return nil
|
||||
@@ -674,9 +707,11 @@ func (e *Engine) removeAllPeers() error {
|
||||
func (e *Engine) removePeer(peerKey string) error {
|
||||
log.Debugf("removing peer from engine %s", peerKey)
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
e.connMgr.RemovePeerConn(peerKey)
|
||||
|
||||
@@ -878,6 +913,7 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
log.Warnf("running SSH server on %s is not supported", runtime.GOOS)
|
||||
return nil
|
||||
}
|
||||
e.sshMux.Lock()
|
||||
// start SSH server if it wasn't running
|
||||
if isNil(e.sshServer) {
|
||||
listenAddr := fmt.Sprintf("%s:%d", e.wgInterface.Address().IP.String(), nbssh.DefaultSSHPort)
|
||||
@@ -885,34 +921,42 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||
listenAddr = fmt.Sprintf("127.0.0.1:%d", nbssh.DefaultSSHPort)
|
||||
}
|
||||
// nil sshServer means it has not yet been started
|
||||
var err error
|
||||
e.sshServer, err = e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
|
||||
server, err := e.sshServerFunc(e.config.SSHKey, listenAddr)
|
||||
if err != nil {
|
||||
e.sshMux.Unlock()
|
||||
return fmt.Errorf("create ssh server: %w", err)
|
||||
}
|
||||
|
||||
e.sshServer = server
|
||||
e.sshMux.Unlock()
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
err = e.sshServer.Start()
|
||||
err = server.Start()
|
||||
if err != nil {
|
||||
// will throw error when we stop it even if it is a graceful stop
|
||||
log.Debugf("stopped SSH server with error %v", err)
|
||||
}
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
e.sshMux.Lock()
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
log.Infof("stopped SSH server")
|
||||
}()
|
||||
} else {
|
||||
e.sshMux.Unlock()
|
||||
log.Debugf("SSH server is already running")
|
||||
}
|
||||
} else if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
} else {
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
// Disable SSH server request, so stop it if it was running
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed to stop SSH server %v", err)
|
||||
}
|
||||
e.sshServer = nil
|
||||
}
|
||||
e.sshServer = nil
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -949,7 +993,9 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
// receiveManagementEvents connects to the Management Service event stream to receive updates from the management service
|
||||
// E.g. when a new peer has been registered and we are allowed to connect to it.
|
||||
func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
@@ -1059,10 +1105,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
protoDNSConfig = &mgmProto.DNSConfig{}
|
||||
}
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)); err != nil {
|
||||
dnsConfig := toDNSConfig(protoDNSConfig, e.wgInterface.Address().Network)
|
||||
|
||||
if err := e.dnsServer.UpdateDNSServer(serial, dnsConfig); err != nil {
|
||||
log.Errorf("failed to update dns server, err: %v", err)
|
||||
}
|
||||
|
||||
e.routeManager.SetDNSForwarderPort(dnsConfig.ForwarderPort)
|
||||
|
||||
// apply routes first, route related actions might depend on routing being enabled
|
||||
routes := toRoutes(networkMap.GetRoutes())
|
||||
serverRoutes, clientRoutes := e.routeManager.ClassifyRoutes(routes)
|
||||
@@ -1121,6 +1171,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
e.statusRecorder.FinishPeerListModifications()
|
||||
|
||||
// update SSHServer by adding remote peer SSH keys
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
for _, config := range networkMap.GetRemotePeers() {
|
||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
||||
@@ -1131,6 +1182,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
}
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
}
|
||||
|
||||
// must set the exclude list after the peers are added. Without it the manager can not figure out the peers parameters from the store
|
||||
@@ -1207,10 +1259,16 @@ func toRouteDomains(myPubKey string, routes []*route.Route) []*dnsfwd.ForwarderE
|
||||
}
|
||||
|
||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network netip.Prefix) nbdns.Config {
|
||||
forwarderPort := uint16(protoDNSConfig.GetForwarderPort())
|
||||
if forwarderPort == 0 {
|
||||
forwarderPort = nbdns.ForwarderClientPort
|
||||
}
|
||||
|
||||
dnsUpdate := nbdns.Config{
|
||||
ServiceEnable: protoDNSConfig.GetServiceEnable(),
|
||||
CustomZones: make([]nbdns.CustomZone, 0),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||
ForwarderPort: forwarderPort,
|
||||
}
|
||||
|
||||
for _, zone := range protoDNSConfig.GetCustomZones() {
|
||||
@@ -1367,7 +1425,9 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
|
||||
// receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers
|
||||
func (e *Engine) receiveSignalEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
e.syncMsgMux.Lock()
|
||||
@@ -1484,12 +1544,14 @@ func (e *Engine) close() {
|
||||
e.statusRecorder.SetWgIface(nil)
|
||||
}
|
||||
|
||||
e.sshMux.Lock()
|
||||
if !isNil(e.sshServer) {
|
||||
err := e.sshServer.Stop()
|
||||
if err != nil {
|
||||
log.Warnf("failed stopping the SSH server: %v", err)
|
||||
}
|
||||
}
|
||||
e.sshMux.Unlock()
|
||||
|
||||
if e.firewall != nil {
|
||||
err := e.firewall.Close(e.stateManager)
|
||||
@@ -1720,8 +1782,10 @@ func (e *Engine) RunHealthProbes(waitForResult bool) bool {
|
||||
return allHealthy
|
||||
}
|
||||
|
||||
// restartEngine restarts the engine by cancelling the client context
|
||||
func (e *Engine) restartEngine() {
|
||||
// triggerClientRestart triggers a full client restart by cancelling the client context.
|
||||
// Note: This does NOT just restart the engine - it cancels the entire client context,
|
||||
// which causes the connect client's retry loop to create a completely new engine.
|
||||
func (e *Engine) triggerClientRestart() {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
@@ -1743,7 +1807,9 @@ func (e *Engine) startNetworkMonitor() {
|
||||
}
|
||||
|
||||
e.networkMonitor = networkmonitor.New()
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
if err := e.networkMonitor.Listen(e.ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Infof("network monitor stopped")
|
||||
@@ -1753,8 +1819,8 @@ func (e *Engine) startNetworkMonitor() {
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("Network monitor: detected network change, restarting engine")
|
||||
e.restartEngine()
|
||||
log.Infof("Network monitor: detected network change, triggering client restart")
|
||||
e.triggerClientRestart()
|
||||
}()
|
||||
}
|
||||
|
||||
@@ -1845,38 +1911,46 @@ func (e *Engine) updateDNSForwarder(
|
||||
}
|
||||
|
||||
if !enabled {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.stopDNSForwarder()
|
||||
return
|
||||
}
|
||||
|
||||
if len(fwdEntries) > 0 {
|
||||
if e.dnsForwardMgr == nil {
|
||||
localAddr := e.wgInterface.Address().IP
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, localAddr)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
e.startDNSForwarder(fwdEntries)
|
||||
} else {
|
||||
e.dnsForwardMgr.UpdateDomains(fwdEntries)
|
||||
}
|
||||
} else if e.dnsForwardMgr != nil {
|
||||
log.Infof("disable domain router service")
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
e.dnsForwardMgr = nil
|
||||
e.stopDNSForwarder()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Engine) startDNSForwarder(fwdEntries []*dnsfwd.ForwarderEntry) {
|
||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder, e.wgInterface)
|
||||
|
||||
if err := e.dnsForwardMgr.Start(fwdEntries); err != nil {
|
||||
log.Errorf("failed to start DNS forward: %v", err)
|
||||
e.dnsForwardMgr = nil
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("started domain router service with %d entries", len(fwdEntries))
|
||||
}
|
||||
|
||||
func (e *Engine) stopDNSForwarder() {
|
||||
if e.dnsForwardMgr == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := e.dnsForwardMgr.Stop(context.Background()); err != nil {
|
||||
log.Errorf("failed to stop DNS forward: %v", err)
|
||||
}
|
||||
|
||||
e.dnsForwardMgr = nil
|
||||
}
|
||||
|
||||
func (e *Engine) GetNet() (*netstack.Net, error) {
|
||||
e.syncMsgMux.Lock()
|
||||
intf := e.wgInterface
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
// Manager handles netflow tracking and logging
|
||||
type Manager struct {
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
logger nftypes.FlowLogger
|
||||
flowConfig *nftypes.FlowConfig
|
||||
conntrack nftypes.ConnTracker
|
||||
@@ -105,8 +106,15 @@ func (m *Manager) resetClient() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
m.cancel = cancel
|
||||
|
||||
go m.receiveACKs(ctx, flowClient)
|
||||
go m.startSender(ctx)
|
||||
m.shutdownWg.Add(2)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.receiveACKs(ctx, flowClient)
|
||||
}()
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
m.startSender(ctx)
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -176,11 +184,12 @@ func (m *Manager) Update(update *nftypes.FlowConfig) error {
|
||||
// Close cleans up all resources
|
||||
func (m *Manager) Close() {
|
||||
m.mux.Lock()
|
||||
defer m.mux.Unlock()
|
||||
|
||||
if err := m.disableFlow(); err != nil {
|
||||
log.Warnf("failed to disable flow manager: %v", err)
|
||||
}
|
||||
m.mux.Unlock()
|
||||
|
||||
m.shutdownWg.Wait()
|
||||
}
|
||||
|
||||
// GetLogger returns the flow logger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (darwin && !ios) || dragonfly || freebsd || netbsd || openbsd
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd
|
||||
|
||||
package networkmonitor
|
||||
|
||||
@@ -6,21 +6,19 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := unix.Close(fd)
|
||||
if err != nil && !errors.Is(err, unix.EBADF) {
|
||||
@@ -28,72 +26,5 @@ func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) er
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
return routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
}
|
||||
|
||||
92
client/internal/networkmonitor/check_change_common.go
Normal file
92
client/internal/networkmonitor/check_change_common.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build dragonfly || freebsd || netbsd || openbsd || darwin
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
func prepareFd() (int, error) {
|
||||
return unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
}
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, buf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse RIB: %v", err)
|
||||
}
|
||||
|
||||
if len(msgs) != 1 {
|
||||
return nil, fmt.Errorf("unexpected RIB message msgs: %v", msgs)
|
||||
}
|
||||
|
||||
msg, ok := msgs[0].(*route.RouteMessage)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unexpected RIB message type: %T", msgs[0])
|
||||
}
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
149
client/internal/networkmonitor/check_change_darwin.go
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package networkmonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
)
|
||||
|
||||
// todo: refactor to not use static functions
|
||||
|
||||
func checkChange(ctx context.Context, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
fd, err := prepareFd()
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %v", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil {
|
||||
if !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("Network monitor: failed to close routing socket: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
routeChanged := make(chan struct{})
|
||||
go func() {
|
||||
_ = routeCheck(ctx, fd, nexthopv4, nexthopv6)
|
||||
close(routeChanged)
|
||||
}()
|
||||
|
||||
wakeUp := make(chan struct{})
|
||||
go func() {
|
||||
wakeUpListen(ctx)
|
||||
close(wakeUp)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-routeChanged:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("route change detected")
|
||||
return nil
|
||||
case <-wakeUp:
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
log.Infof("wakeup detected")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func wakeUpListen(ctx context.Context) {
|
||||
log.Infof("start to watch for system wakeups")
|
||||
var (
|
||||
initialHash uint32
|
||||
err error
|
||||
)
|
||||
|
||||
// Keep retrying until initial sysctl succeeds or context is canceled
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
default:
|
||||
initialHash, err = readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to detect initial sleep time: %v", err)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("exit from wakeUpListen initial hash detection due to context cancellation")
|
||||
return
|
||||
case <-time.After(3 * time.Second):
|
||||
continue
|
||||
}
|
||||
}
|
||||
log.Debugf("initial wakeup hash: %d", initialHash)
|
||||
break
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Info("context canceled, stopping wakeUpListen")
|
||||
return
|
||||
|
||||
case <-ticker.C:
|
||||
newHash, err := readSleepTimeHash()
|
||||
if err != nil {
|
||||
log.Errorf("failed to read sleep time hash: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if newHash == initialHash {
|
||||
log.Tracef("no wakeup detected")
|
||||
continue
|
||||
}
|
||||
|
||||
upOut, err := exec.Command("uptime").Output()
|
||||
if err != nil {
|
||||
log.Errorf("failed to run uptime command: %v", err)
|
||||
upOut = []byte("unknown")
|
||||
}
|
||||
log.Infof("Wakeup detected: %d -> %d, uptime: %s", initialHash, newHash, upOut)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func readSleepTimeHash() (uint32, error) {
|
||||
cmd := exec.Command("sysctl", "kern.sleeptime")
|
||||
out, err := cmd.Output()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to run sysctl: %w", err)
|
||||
}
|
||||
|
||||
h, err := hash(out)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to compute hash: %w", err)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
func hash(data []byte) (uint32, error) {
|
||||
hasher := fnv.New32a() // Create a new 32-bit FNV-1a hasher
|
||||
if _, err := hasher.Write(data); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return hasher.Sum32(), nil
|
||||
}
|
||||
@@ -88,6 +88,7 @@ func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||
event := make(chan struct{}, 1)
|
||||
go nw.checkChanges(ctx, event, nexthop4, nexthop6)
|
||||
|
||||
log.Infof("start watching for network changes")
|
||||
// debounce changes
|
||||
timer := time.NewTimer(0)
|
||||
timer.Stop()
|
||||
|
||||
@@ -19,11 +19,10 @@ type SRWatcher struct {
|
||||
signalClient chNotifier
|
||||
relayManager chNotifier
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
|
||||
listeners map[chan struct{}]struct{}
|
||||
mu sync.Mutex
|
||||
iFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
iceConfig ice.Config
|
||||
cancelIceMonitor context.CancelFunc
|
||||
}
|
||||
|
||||
|
||||
@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
|
||||
|
||||
func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
|
||||
if isController(w.config) {
|
||||
return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
} else {
|
||||
return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
|
||||
}
|
||||
|
||||
@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
|
||||
DisablePromptLogin bool
|
||||
// LoginFlag is used to configure the PKCE flow login behavior
|
||||
LoginFlag common.LoginFlag
|
||||
// LoginHint is used to pre-fill the email/username field during authentication
|
||||
LoginHint string
|
||||
}
|
||||
|
||||
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||
@@ -25,4 +26,5 @@ type HandlerParams struct {
|
||||
UseNewDNSRoute bool
|
||||
Firewall manager.Manager
|
||||
FakeIPManager *fakeip.Manager
|
||||
ForwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
@@ -20,7 +21,6 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
pkgdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/common"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/fakeip"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||
@@ -55,6 +55,7 @@ type DnsInterceptor struct {
|
||||
peerStore *peerstore.Store
|
||||
firewall firewall.Manager
|
||||
fakeIPManager *fakeip.Manager
|
||||
forwarderPort *atomic.Uint32
|
||||
}
|
||||
|
||||
func New(params common.HandlerParams) *DnsInterceptor {
|
||||
@@ -69,6 +70,7 @@ func New(params common.HandlerParams) *DnsInterceptor {
|
||||
firewall: params.Firewall,
|
||||
fakeIPManager: params.FakeIPManager,
|
||||
interceptedDomains: make(domainMap),
|
||||
forwarderPort: params.ForwarderPort,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -257,7 +259,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), pkgdns.ForwarderClientPort)
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"runtime"
|
||||
"slices"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -23,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/client"
|
||||
@@ -54,6 +56,7 @@ type Manager interface {
|
||||
SetRouteChangeListener(listener listener.NetworkChangeListener)
|
||||
InitialRouteRange() []string
|
||||
SetFirewall(firewall.Manager) error
|
||||
SetDNSForwarderPort(port uint16)
|
||||
Stop(stateManager *statemanager.Manager)
|
||||
}
|
||||
|
||||
@@ -78,6 +81,7 @@ type DefaultManager struct {
|
||||
ctx context.Context
|
||||
stop context.CancelFunc
|
||||
mux sync.Mutex
|
||||
shutdownWg sync.WaitGroup
|
||||
clientNetworks map[route.HAUniqueID]*client.Watcher
|
||||
routeSelector *routeselector.RouteSelector
|
||||
serverRouter *server.Router
|
||||
@@ -101,6 +105,7 @@ type DefaultManager struct {
|
||||
disableServerRoutes bool
|
||||
activeRoutes map[route.HAUniqueID]client.RouteHandler
|
||||
fakeIPManager *fakeip.Manager
|
||||
dnsForwarderPort atomic.Uint32
|
||||
}
|
||||
|
||||
func NewManager(config ManagerConfig) *DefaultManager {
|
||||
@@ -130,6 +135,7 @@ func NewManager(config ManagerConfig) *DefaultManager {
|
||||
disableServerRoutes: config.DisableServerRoutes,
|
||||
activeRoutes: make(map[route.HAUniqueID]client.RouteHandler),
|
||||
}
|
||||
dm.dnsForwarderPort.Store(uint32(nbdns.ForwarderClientPort))
|
||||
|
||||
useNoop := netstack.IsEnabled() || config.DisableClientRoutes
|
||||
dm.setupRefCounters(useNoop)
|
||||
@@ -270,9 +276,15 @@ func (m *DefaultManager) SetFirewall(firewall firewall.Manager) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort sets the DNS forwarder port for route handlers
|
||||
func (m *DefaultManager) SetDNSForwarderPort(port uint16) {
|
||||
m.dnsForwarderPort.Store(uint32(port))
|
||||
}
|
||||
|
||||
// Stop stops the manager watchers and clean firewall rules
|
||||
func (m *DefaultManager) Stop(stateManager *statemanager.Manager) {
|
||||
m.stop()
|
||||
m.shutdownWg.Wait()
|
||||
if m.serverRouter != nil {
|
||||
m.serverRouter.CleanUp()
|
||||
}
|
||||
@@ -345,6 +357,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
||||
UseNewDNSRoute: m.useNewDNSRoute,
|
||||
Firewall: m.firewall,
|
||||
FakeIPManager: m.fakeIPManager,
|
||||
ForwarderPort: &m.dnsForwarderPort,
|
||||
}
|
||||
handler := client.HandlerFromRoute(params)
|
||||
if err := handler.AddRoute(m.ctx); err != nil {
|
||||
@@ -474,7 +487,11 @@ func (m *DefaultManager) TriggerSelection(networks route.HAMap) {
|
||||
}
|
||||
clientNetworkWatcher := client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
clientNetworkWatcher.SendUpdate(client.RoutesUpdate{Routes: routes})
|
||||
}
|
||||
|
||||
@@ -516,7 +533,11 @@ func (m *DefaultManager) updateClientNetworks(updateSerial uint64, networks rout
|
||||
}
|
||||
clientNetworkWatcher = client.NewWatcher(config)
|
||||
m.clientNetworks[id] = clientNetworkWatcher
|
||||
go clientNetworkWatcher.Start()
|
||||
m.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer m.shutdownWg.Done()
|
||||
clientNetworkWatcher.Start()
|
||||
}()
|
||||
}
|
||||
update := client.RoutesUpdate{
|
||||
UpdateSerial: updateSerial,
|
||||
|
||||
@@ -90,6 +90,10 @@ func (m *MockManager) SetFirewall(firewall.Manager) error {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
// SetDNSForwarderPort mock implementation of SetDNSForwarderPort from Manager interface
|
||||
func (m *MockManager) SetDNSForwarderPort(port uint16) {
|
||||
}
|
||||
|
||||
// Stop mock implementation of Stop from Manager interface
|
||||
func (m *MockManager) Stop(stateManager *statemanager.Manager) {
|
||||
if m.StopFunc != nil {
|
||||
|
||||
@@ -9,8 +9,6 @@ import (
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
@@ -128,13 +126,11 @@ func (rs *RouteSelector) IsSelected(routeID route.NetID) bool {
|
||||
defer rs.mu.RUnlock()
|
||||
|
||||
if rs.deselectAll {
|
||||
log.Debugf("Route %s not selected (deselect all)", routeID)
|
||||
return false
|
||||
}
|
||||
|
||||
_, deselected := rs.deselectedRoutes[routeID]
|
||||
isSelected := !deselected
|
||||
log.Debugf("Route %s selection status: %v (deselected: %v)", routeID, isSelected, deselected)
|
||||
return isSelected
|
||||
}
|
||||
|
||||
|
||||
@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
|
||||
ConfigPath: c.cfgFile,
|
||||
})
|
||||
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false)
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
@@ -17,8 +17,7 @@ type Conn struct {
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.Conn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *Conn) Close() error {
|
||||
return closeConn(c.ID, c.Conn)
|
||||
}
|
||||
@@ -29,7 +28,7 @@ type TCPConn struct {
|
||||
ID hooks.ConnectionID
|
||||
}
|
||||
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.TCPConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *TCPConn) Close() error {
|
||||
return closeConn(c.ID, c.TCPConn)
|
||||
}
|
||||
@@ -37,13 +36,16 @@ func (c *TCPConn) Close() error {
|
||||
// closeConn is a helper function to close connections and execute close hooks.
|
||||
func closeConn(id hooks.ConnectionID, conn io.Closer) error {
|
||||
err := conn.Close()
|
||||
cleanupConnID(id)
|
||||
return err
|
||||
}
|
||||
|
||||
// cleanupConnID executes close hooks for a connection ID.
|
||||
func cleanupConnID(id hooks.ConnectionID) {
|
||||
closeHooks := hooks.GetCloseHooks()
|
||||
for _, hook := range closeHooks {
|
||||
if err := hook(id); err != nil {
|
||||
log.Errorf("Error executing close hook: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -74,7 +74,6 @@ func DialTCP(network string, laddr, raddr *net.TCPAddr) (transport.TCPConn, erro
|
||||
}
|
||||
return &TCPConn{TCPConn: tcpConn, ID: c.ID}, nil
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Errorf("failed to close connection: %v", err)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,7 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (net.
|
||||
|
||||
conn, err := d.Dialer.DialContext(ctx, network, address)
|
||||
if err != nil {
|
||||
cleanupConnID(connID)
|
||||
return nil, fmt.Errorf("d.Dialer.DialContext: %w", err)
|
||||
}
|
||||
|
||||
@@ -64,7 +65,7 @@ func callDialerHooks(ctx context.Context, connID hooks.ConnectionID, address str
|
||||
|
||||
ips, err := resolver.LookupIPAddr(ctx, host)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to resolve address %s: %w", address, err)
|
||||
return fmt.Errorf("resolve address %s: %w", address, err)
|
||||
}
|
||||
|
||||
log.Debugf("Dialer resolved IPs for %s: %v", address, ips)
|
||||
|
||||
@@ -48,7 +48,7 @@ func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.PacketConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.PacketConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *PacketConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.PacketConn)
|
||||
@@ -69,7 +69,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
|
||||
return c.UDPConn.WriteTo(b, addr)
|
||||
}
|
||||
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks before closing the connection.
|
||||
// Close overrides the net.UDPConn Close method to execute all registered hooks after closing the connection.
|
||||
func (c *UDPConn) Close() error {
|
||||
defer c.seenAddrs.Clear()
|
||||
return closeConn(c.ID, c.UDPConn)
|
||||
|
||||
@@ -279,8 +279,10 @@ type LoginRequest struct {
|
||||
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
||||
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
||||
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
|
||||
unknownFields protoimpl.UnknownFields
|
||||
sizeCache protoimpl.SizeCache
|
||||
}
|
||||
|
||||
func (x *LoginRequest) Reset() {
|
||||
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (x *LoginRequest) GetHint() string {
|
||||
if x != nil && x.Hint != nil {
|
||||
return *x.Hint
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
type LoginResponse struct {
|
||||
state protoimpl.MessageState `protogen:"open.v1"`
|
||||
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
||||
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
|
||||
const file_daemon_proto_rawDesc = "" +
|
||||
"\n" +
|
||||
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
|
||||
"\fEmptyRequest\"\xc3\x0e\n" +
|
||||
"\fEmptyRequest\"\xe5\x0e\n" +
|
||||
"\fLoginRequest\x12\x1a\n" +
|
||||
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
|
||||
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
|
||||
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
|
||||
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
|
||||
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" +
|
||||
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
|
||||
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
|
||||
"\x11_rosenpassEnabledB\x10\n" +
|
||||
"\x0e_interfaceNameB\x10\n" +
|
||||
"\x0e_wireguardPortB\x17\n" +
|
||||
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
|
||||
"\x0e_block_inboundB\x0e\n" +
|
||||
"\f_profileNameB\v\n" +
|
||||
"\t_usernameB\x06\n" +
|
||||
"\x04_mtu\"\xb5\x01\n" +
|
||||
"\x04_mtuB\a\n" +
|
||||
"\x05_hint\"\xb5\x01\n" +
|
||||
"\rLoginResponse\x12$\n" +
|
||||
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
|
||||
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
|
||||
|
||||
@@ -158,6 +158,9 @@ message LoginRequest {
|
||||
optional string username = 31;
|
||||
|
||||
optional int64 mtu = 32;
|
||||
|
||||
// hint is used to pre-fill the email/username field during SSO authentication
|
||||
optional string hint = 33;
|
||||
}
|
||||
|
||||
message LoginResponse {
|
||||
|
||||
@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
state.Set(internal.StatusConnecting)
|
||||
|
||||
if msg.SetupKey == "" {
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient)
|
||||
hint := ""
|
||||
if msg.Hint != nil {
|
||||
hint = *msg.Hint
|
||||
}
|
||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
|
||||
if err != nil {
|
||||
state.Set(internal.StatusLoginFailed)
|
||||
return nil, err
|
||||
|
||||
@@ -296,6 +296,8 @@ type serviceClient struct {
|
||||
mExitNodeDeselectAll *systray.MenuItem
|
||||
logFile string
|
||||
wLoginURL fyne.Window
|
||||
|
||||
connectCancel context.CancelFunc
|
||||
}
|
||||
|
||||
type menuHandler struct {
|
||||
@@ -592,17 +594,15 @@ func (s *serviceClient) getSettingsForm() *widget.Form {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginResponse, error) {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfile()
|
||||
if err != nil {
|
||||
log.Errorf("get active profile: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
currUser, err := user.Current()
|
||||
@@ -610,84 +610,80 @@ func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||
return nil, fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(s.ctx, &proto.LoginRequest{
|
||||
loginReq := &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
}
|
||||
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Errorf("login to management URL with: %v", err)
|
||||
return nil, err
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
loginReq.Hint = &profileState.Email
|
||||
}
|
||||
|
||||
loginResp, err := conn.Login(ctx, loginReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("login to management: %w", err)
|
||||
}
|
||||
|
||||
if loginResp.NeedsSSOLogin && openURL {
|
||||
err = s.handleSSOLogin(loginResp, conn)
|
||||
if err != nil {
|
||||
log.Errorf("handle SSO login failed: %v", err)
|
||||
return nil, err
|
||||
if err = s.handleSSOLogin(ctx, loginResp, conn); err != nil {
|
||||
return nil, fmt.Errorf("SSO login: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleSSOLogin(loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
err := openURL(loginResp.VerificationURIComplete)
|
||||
if err != nil {
|
||||
log.Errorf("opening the verification uri in the browser failed: %v", err)
|
||||
return err
|
||||
func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.LoginResponse, conn proto.DaemonServiceClient) error {
|
||||
if err := openURL(loginResp.VerificationURIComplete); err != nil {
|
||||
return fmt.Errorf("open browser: %w", err)
|
||||
}
|
||||
|
||||
resp, err := conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
resp, err := conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("waiting sso login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("wait for SSO login: %w", err)
|
||||
}
|
||||
|
||||
if resp.Email != "" {
|
||||
err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ProfileState{
|
||||
Email: resp.Email,
|
||||
})
|
||||
if err != nil {
|
||||
log.Warnf("failed to set profile state: %v", err)
|
||||
}); err != nil {
|
||||
log.Debugf("failed to set profile state: %v", err)
|
||||
} else {
|
||||
s.mProfile.refresh()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) menuUpClick() error {
|
||||
func (s *serviceClient) menuUpClick(ctx context.Context) error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
systray.SetTemplateIcon(iconErrorMacOS, s.icError)
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.login(true)
|
||||
_, err = s.login(ctx, true)
|
||||
if err != nil {
|
||||
log.Errorf("login failed with: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status == string(internal.StatusConnected) {
|
||||
log.Warnf("already connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||
log.Errorf("up service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||
return fmt.Errorf("start connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -697,24 +693,20 @@ func (s *serviceClient) menuDownClick() error {
|
||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
log.Errorf("get client: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get daemon client: %w", err)
|
||||
}
|
||||
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return err
|
||||
return fmt.Errorf("get status: %w", err)
|
||||
}
|
||||
|
||||
if status.Status != string(internal.StatusConnected) && status.Status != string(internal.StatusConnecting) {
|
||||
log.Warnf("already down")
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := s.conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
log.Errorf("down service: %v", err)
|
||||
return err
|
||||
if _, err := conn.Down(s.ctx, &proto.DownRequest{}); err != nil {
|
||||
return fmt.Errorf("stop connection: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -850,6 +842,7 @@ func (s *serviceClient) onTrayReady() {
|
||||
|
||||
newProfileMenuArgs := &newProfileMenuArgs{
|
||||
ctx: s.ctx,
|
||||
serviceClient: s,
|
||||
profileManager: s.profileManager,
|
||||
eventHandler: s.eventHandler,
|
||||
profileMenuItem: profileMenuItem,
|
||||
@@ -1381,7 +1374,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := s.login(false)
|
||||
resp, err := s.login(ctx, false)
|
||||
if err != nil {
|
||||
log.Errorf("failed to fetch login URL: %v", err)
|
||||
return
|
||||
@@ -1401,7 +1394,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.WaitSSOLogin(s.ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
_, err = conn.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: resp.UserCode})
|
||||
if err != nil {
|
||||
log.Errorf("Waiting sso login failed with: %v", err)
|
||||
label.SetText("Waiting login failed, please create \na debug bundle in the settings and contact support.")
|
||||
@@ -1409,7 +1402,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
}
|
||||
|
||||
label.SetText("Re-authentication successful.\nReconnecting")
|
||||
status, err := conn.Status(s.ctx, &proto.StatusRequest{})
|
||||
status, err := conn.Status(ctx, &proto.StatusRequest{})
|
||||
if err != nil {
|
||||
log.Errorf("get service status: %v", err)
|
||||
return
|
||||
@@ -1422,7 +1415,7 @@ func (s *serviceClient) showLoginURL() context.CancelFunc {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = conn.Up(s.ctx, &proto.UpRequest{})
|
||||
_, err = conn.Up(ctx, &proto.UpRequest{})
|
||||
if err != nil {
|
||||
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")
|
||||
log.Errorf("Reconnecting failed with: %v", err)
|
||||
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"fyne.io/fyne/v2"
|
||||
"fyne.io/systray"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
@@ -67,20 +69,55 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
|
||||
func (h *eventHandler) handleConnectClick() {
|
||||
h.client.mUp.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
h.client.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(h.client.ctx)
|
||||
h.client.connectCancel = connectCancel
|
||||
|
||||
go func() {
|
||||
defer h.client.mUp.Enable()
|
||||
if err := h.client.menuUpClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird service"))
|
||||
defer connectCancel()
|
||||
|
||||
if err := h.client.menuUpClick(connectCtx); err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||
log.Debugf("connect operation cancelled by user")
|
||||
} else {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||
log.Errorf("connect failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after connect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleDisconnectClick() {
|
||||
h.client.mDown.Disable()
|
||||
|
||||
if h.client.connectCancel != nil {
|
||||
log.Debugf("cancelling ongoing connect operation")
|
||||
h.client.connectCancel()
|
||||
h.client.connectCancel = nil
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer h.client.mDown.Enable()
|
||||
if err := h.client.menuDownClick(); err != nil {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect to NetBird daemon"))
|
||||
st, ok := status.FromError(err)
|
||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||
log.Errorf("disconnect failed: %v", err)
|
||||
} else {
|
||||
log.Debugf("disconnect cancelled or already disconnecting")
|
||||
}
|
||||
}
|
||||
|
||||
if err := h.client.updateStatus(); err != nil {
|
||||
log.Debugf("failed to update status after disconnect: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -245,6 +282,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
|
||||
}
|
||||
|
||||
h.client.getSrvConfig()
|
||||
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -387,6 +387,7 @@ type subItem struct {
|
||||
type profileMenu struct {
|
||||
mu sync.Mutex
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
@@ -396,7 +397,7 @@ type profileMenu struct {
|
||||
logoutSubItem *subItem
|
||||
profilesState []Profile
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -404,12 +405,13 @@ type profileMenu struct {
|
||||
|
||||
type newProfileMenuArgs struct {
|
||||
ctx context.Context
|
||||
serviceClient *serviceClient
|
||||
profileManager *profilemanager.ProfileManager
|
||||
eventHandler *eventHandler
|
||||
profileMenuItem *systray.MenuItem
|
||||
emailMenuItem *systray.MenuItem
|
||||
downClickCallback func() error
|
||||
upClickCallback func() error
|
||||
upClickCallback func(context.Context) error
|
||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||
loadSettingsCallback func()
|
||||
app fyne.App
|
||||
@@ -418,6 +420,7 @@ type newProfileMenuArgs struct {
|
||||
func newProfileMenu(args newProfileMenuArgs) *profileMenu {
|
||||
p := profileMenu{
|
||||
ctx: args.ctx,
|
||||
serviceClient: args.serviceClient,
|
||||
profileManager: args.profileManager,
|
||||
eventHandler: args.eventHandler,
|
||||
profileMenuItem: args.profileMenuItem,
|
||||
@@ -569,10 +572,19 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
}
|
||||
|
||||
if err := p.upClickCallback(); err != nil {
|
||||
if p.serviceClient.connectCancel != nil {
|
||||
p.serviceClient.connectCancel()
|
||||
}
|
||||
|
||||
connectCtx, connectCancel := context.WithCancel(p.ctx)
|
||||
p.serviceClient.connectCancel = connectCancel
|
||||
|
||||
if err := p.upClickCallback(connectCtx); err != nil {
|
||||
log.Errorf("failed to handle up click after switching profile: %v", err)
|
||||
}
|
||||
|
||||
connectCancel()
|
||||
|
||||
p.refresh()
|
||||
p.loadSettingsCallback()
|
||||
}
|
||||
|
||||
@@ -35,6 +35,8 @@ type Config struct {
|
||||
NameServerGroups []*NameServerGroup
|
||||
// CustomZones contains a list of custom zone
|
||||
CustomZones []CustomZone
|
||||
// ForwarderPort is the port clients should connect to on routing peers for DNS forwarding
|
||||
ForwarderPort uint16
|
||||
}
|
||||
|
||||
// CustomZone represents a custom zone to be resolved by the dns server
|
||||
|
||||
14
go.mod
14
go.mod
@@ -56,6 +56,7 @@ require (
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
|
||||
github.com/hashicorp/go-version v1.6.0
|
||||
github.com/jackc/pgx/v5 v5.5.5
|
||||
github.com/libdns/route53 v1.5.0
|
||||
github.com/libp2p/go-netroute v0.2.1
|
||||
github.com/mdlayher/socket v0.5.1
|
||||
@@ -76,7 +77,7 @@ require (
|
||||
github.com/pion/transport/v3 v3.0.7
|
||||
github.com/pion/turn/v3 v3.0.1
|
||||
github.com/prometheus/client_golang v1.22.0
|
||||
github.com/quic-go/quic-go v0.48.2
|
||||
github.com/quic-go/quic-go v0.49.1
|
||||
github.com/redis/go-redis/v9 v9.7.3
|
||||
github.com/rs/xid v1.3.0
|
||||
github.com/shirou/gopsutil/v3 v3.24.4
|
||||
@@ -102,11 +103,12 @@ require (
|
||||
goauthentik.io/api/v3 v3.2023051.3
|
||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||
golang.org/x/mod v0.25.0
|
||||
golang.org/x/mod v0.26.0
|
||||
golang.org/x/net v0.42.0
|
||||
golang.org/x/oauth2 v0.28.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/sync v0.16.0
|
||||
golang.org/x/term v0.33.0
|
||||
golang.org/x/time v0.12.0
|
||||
google.golang.org/api v0.177.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/mysql v1.5.7
|
||||
@@ -146,7 +148,7 @@ require (
|
||||
github.com/beorn7/perks v1.0.1 // indirect
|
||||
github.com/caddyserver/zerossl v0.1.3 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/containerd/containerd v1.7.27 // indirect
|
||||
github.com/containerd/containerd v1.7.29 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/containerd/platforms v0.2.1 // indirect
|
||||
github.com/cpuguy83/dockercfg v0.3.2 // indirect
|
||||
@@ -183,7 +185,6 @@ require (
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
github.com/jackc/puddle/v2 v2.2.1 // indirect
|
||||
github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
@@ -241,11 +242,10 @@ require (
|
||||
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.51.0 // indirect
|
||||
go.opentelemetry.io/otel/sdk v1.35.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.35.0 // indirect
|
||||
go.uber.org/mock v0.4.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.18.0 // indirect
|
||||
golang.org/x/text v0.27.0 // indirect
|
||||
golang.org/x/time v0.5.0 // indirect
|
||||
golang.org/x/tools v0.34.0 // indirect
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
|
||||
github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
|
||||
github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
|
||||
github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
|
||||
github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
|
||||
github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
|
||||
github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
|
||||
github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
|
||||
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
|
||||
github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
|
||||
github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
|
||||
@@ -590,8 +590,8 @@ github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ
|
||||
github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I=
|
||||
github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc=
|
||||
github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk=
|
||||
github.com/quic-go/quic-go v0.48.2 h1:wsKXZPeGWpMpCGSWqOcqpW2wZYic/8T3aqiOID0/KWE=
|
||||
github.com/quic-go/quic-go v0.48.2/go.mod h1:yBgs3rWBOADpga7F+jJsb6Ybg1LSYiQvwWlLX+/6HMs=
|
||||
github.com/quic-go/quic-go v0.49.1 h1:e5JXpUyF0f2uFjckQzD8jTghZrOUK1xxDqqZhlwixo0=
|
||||
github.com/quic-go/quic-go v0.49.1/go.mod h1:s2wDnmCdooUQBmQfpUSTCYBl1/D4FcqbULMMkASvR6s=
|
||||
github.com/redis/go-redis/v9 v9.7.3 h1:YpPyAayJV+XErNsatSElgRZZVCwXX9QzkKYNvO7x0wM=
|
||||
github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRltRTZHSvrA=
|
||||
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
|
||||
@@ -749,8 +749,8 @@ go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v8
|
||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU=
|
||||
go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU=
|
||||
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
|
||||
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
@@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
@@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
|
||||
golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
|
||||
golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
|
||||
golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
|
||||
golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
|
||||
golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
||||
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
|
||||
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
|
||||
|
||||
@@ -1,11 +1,19 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
"log"
|
||||
"net/http"
|
||||
// nolint:gosec
|
||||
_ "net/http/pprof"
|
||||
"os"
|
||||
|
||||
"github.com/netbirdio/netbird/management/cmd"
|
||||
)
|
||||
|
||||
func main() {
|
||||
go func() {
|
||||
log.Println(http.ListenAndServe("localhost:6060", nil))
|
||||
}()
|
||||
if err := cmd.Execute(); err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -53,6 +53,9 @@ const (
|
||||
peerSchedulerRetryInterval = 3 * time.Second
|
||||
emptyUserID = "empty user ID in claims"
|
||||
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
|
||||
|
||||
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||
)
|
||||
|
||||
type userLoggedInOnce bool
|
||||
@@ -109,6 +112,11 @@ type DefaultAccountManager struct {
|
||||
loginFilter *loginFilter
|
||||
|
||||
disableDefaultPolicy bool
|
||||
|
||||
holder *types.Holder
|
||||
|
||||
expNewNetworkMap bool
|
||||
expNewNetworkMapAIDs map[string]struct{}
|
||||
}
|
||||
|
||||
func isUniqueConstraintError(err error) bool {
|
||||
@@ -196,6 +204,18 @@ func BuildManager(
|
||||
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
|
||||
}()
|
||||
|
||||
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
|
||||
newNetworkMapBuilder = false
|
||||
}
|
||||
|
||||
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
|
||||
expIDs := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
expIDs[id] = struct{}{}
|
||||
}
|
||||
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
geo: geo,
|
||||
@@ -217,6 +237,10 @@ func BuildManager(
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
disableDefaultPolicy: disableDefaultPolicy,
|
||||
holder: types.NewHolder(),
|
||||
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
expNewNetworkMapAIDs: expIDs,
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
}
|
||||
|
||||
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
|
||||
}
|
||||
|
||||
if removedGroupAffectsPeers || newGroupsAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
|
||||
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
|
||||
}
|
||||
@@ -1641,11 +1672,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
|
||||
@@ -2129,6 +2155,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateNetworkMap {
|
||||
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -128,4 +128,5 @@ type Manager interface {
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
SetEphemeralManager(em ephemeral.Manager)
|
||||
AllowSync(string, uint64) bool
|
||||
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
@@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SaveGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_SavePolicy(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeletePeer(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
group := types.Group{
|
||||
@@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
testAccountManager_NetworkUpdates_DeleteGroup(t)
|
||||
}
|
||||
|
||||
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
@@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
for drained := false; !drained; {
|
||||
select {
|
||||
case <-updMsg:
|
||||
default:
|
||||
drained = true
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Address: "172.12.6.1/24",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
err := hasNilField(account)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -273,15 +274,21 @@ func configureConnectionPool(db *gorm.DB, storeEngine types.Engine) (*gorm.DB, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
} else {
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
conns, err := strconv.Atoi(os.Getenv(sqlMaxOpenConnsEnv))
|
||||
if err != nil {
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
conns = 1
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(conns)
|
||||
sqlDB.SetMaxIdleConns(conns)
|
||||
sqlDB.SetConnMaxLifetime(time.Hour)
|
||||
sqlDB.SetConnMaxIdleTime(3 * time.Minute)
|
||||
|
||||
log.Infof("Set max open db connections to %d, max idle to %d, max lifetime to %v, max idle time to %v",
|
||||
conns, conns, time.Hour, 3*time.Minute)
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
@@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -138,6 +141,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
return err
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
|
||||
@@ -157,11 +165,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
}
|
||||
|
||||
newGroup.AccountID = accountID
|
||||
|
||||
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
|
||||
eventsToStore = append(eventsToStore, events...)
|
||||
|
||||
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -335,6 +347,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
|
||||
if err == nil && oldGroup != nil {
|
||||
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
|
||||
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
|
||||
|
||||
if oldGroup.Name != newGroup.Name {
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
meta := map[string]any{
|
||||
"old_name": oldGroup.Name,
|
||||
"new_name": newGroup.Name,
|
||||
}
|
||||
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
|
||||
})
|
||||
}
|
||||
} else {
|
||||
addedPeers = append(addedPeers, newGroup.Peers...)
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
@@ -471,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -509,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -537,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -575,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
|
||||
}
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -7,8 +7,10 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
@@ -44,6 +46,9 @@ import (
|
||||
const (
|
||||
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
|
||||
envBlockPeers = "NB_BLOCK_SAME_PEERS"
|
||||
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
|
||||
|
||||
defaultSyncLim = 1000
|
||||
)
|
||||
|
||||
// GRPCServer an instance of a Management gRPC API server
|
||||
@@ -63,6 +68,9 @@ type GRPCServer struct {
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
syncSem atomic.Int32
|
||||
syncLim int32
|
||||
}
|
||||
|
||||
// NewServer creates a new Management server
|
||||
@@ -96,6 +104,17 @@ func NewServer(
|
||||
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
|
||||
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
|
||||
|
||||
syncLim := int32(defaultSyncLim)
|
||||
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
|
||||
syncLimParsed, err := strconv.Atoi(syncLimStr)
|
||||
if err != nil {
|
||||
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
|
||||
} else {
|
||||
//nolint:gosec
|
||||
syncLim = int32(syncLimParsed)
|
||||
}
|
||||
}
|
||||
|
||||
return &GRPCServer{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
@@ -110,6 +129,8 @@ func NewServer(
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
|
||||
syncLim: syncLim,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP {
|
||||
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
|
||||
// notifies the connected peer of any updates (e.g. new peers under the same account)
|
||||
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
|
||||
if s.syncSem.Load() >= s.syncLim {
|
||||
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
|
||||
}
|
||||
s.syncSem.Add(1)
|
||||
|
||||
reqStart := time.Now()
|
||||
|
||||
ctx := srv.Context()
|
||||
@@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
syncReq := &proto.SyncRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, syncReq)
|
||||
if err != nil {
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
realIP := getRealIP(ctx)
|
||||
@@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
|
||||
}
|
||||
if s.blockPeersWithSameConfig {
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
|
||||
}
|
||||
}
|
||||
@@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
|
||||
if err != nil {
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
|
||||
log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
|
||||
if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
|
||||
s.syncSem.Add(-1)
|
||||
return status.Errorf(codes.PermissionDenied, "peer is not registered")
|
||||
}
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
// nolint:staticcheck
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
start := time.Now()
|
||||
unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
}
|
||||
}()
|
||||
log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
|
||||
log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
@@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart))
|
||||
|
||||
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
return err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart))
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart))
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart))
|
||||
|
||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart))
|
||||
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
@@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
|
||||
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
|
||||
|
||||
s.syncSem.Add(-1)
|
||||
|
||||
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
|
||||
}
|
||||
|
||||
@@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
|
||||
|
||||
defer func() {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
|
||||
}
|
||||
took := time.Since(reqStart)
|
||||
if took > 7*time.Second {
|
||||
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
|
||||
}
|
||||
}()
|
||||
|
||||
if loginReq.GetMeta() == nil {
|
||||
@@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
@@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
@@ -822,10 +881,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
sendStart := time.Now()
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)
|
||||
|
||||
39
management/server/holder.go
Normal file
39
management/server/holder.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
|
||||
a := am.holder.GetAccount(account.Id)
|
||||
if a == nil {
|
||||
am.holder.AddAccount(account)
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache = a.NetworkMapCache
|
||||
if account.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
account.NetworkMapCache.UpdateAccountPointer(account)
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
|
||||
return am.holder.GetAccount(accountID)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
|
||||
a := am.holder.GetAccount(accountID)
|
||||
if a != nil {
|
||||
return a
|
||||
}
|
||||
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return account
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
|
||||
am.holder.AddAccount(account)
|
||||
}
|
||||
@@ -4,9 +4,13 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
@@ -38,7 +42,12 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const apiPrefix = "/api"
|
||||
const (
|
||||
apiPrefix = "/api"
|
||||
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
|
||||
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
|
||||
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
|
||||
)
|
||||
|
||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||
func NewAPIHandler(
|
||||
@@ -58,11 +67,42 @@ func NewAPIHandler(
|
||||
settingsManager settings.Manager,
|
||||
) (http.Handler, error) {
|
||||
|
||||
var rateLimitingConfig *middleware.RateLimiterConfig
|
||||
if os.Getenv(rateLimitingEnabledKey) == "true" {
|
||||
rpm := 6
|
||||
if v := os.Getenv(rateLimitingRPMKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
|
||||
} else {
|
||||
rpm = value
|
||||
}
|
||||
}
|
||||
|
||||
burst := 500
|
||||
if v := os.Getenv(rateLimitingBurstKey); v != "" {
|
||||
value, err := strconv.Atoi(v)
|
||||
if err != nil {
|
||||
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
|
||||
} else {
|
||||
burst = value
|
||||
}
|
||||
}
|
||||
|
||||
rateLimitingConfig = &middleware.RateLimiterConfig{
|
||||
RequestsPerMinute: float64(rpm),
|
||||
Burst: burst,
|
||||
CleanupInterval: 6 * time.Hour,
|
||||
LimiterTTL: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
authManager,
|
||||
accountManager.GetAccountIDFromUserAuth,
|
||||
accountManager.SyncUserJWTGroups,
|
||||
accountManager.GetUserFromUserAuth,
|
||||
rateLimitingConfig,
|
||||
)
|
||||
|
||||
corsMiddleware := cors.AllowAll()
|
||||
|
||||
@@ -29,6 +29,7 @@ type AuthMiddleware struct {
|
||||
ensureAccount EnsureAccountFunc
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||
rateLimiter *APIRateLimiter
|
||||
}
|
||||
|
||||
// NewAuthMiddleware instance constructor
|
||||
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
|
||||
ensureAccount EnsureAccountFunc,
|
||||
syncUserJWTGroups SyncUserJWTGroupsFunc,
|
||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||
rateLimiterConfig *RateLimiterConfig,
|
||||
) *AuthMiddleware {
|
||||
var rateLimiter *APIRateLimiter
|
||||
if rateLimiterConfig != nil {
|
||||
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
|
||||
}
|
||||
|
||||
return &AuthMiddleware{
|
||||
authManager: authManager,
|
||||
ensureAccount: ensureAccount,
|
||||
syncUserJWTGroups: syncUserJWTGroups,
|
||||
getUserFromUserAuth: getUserFromUserAuth,
|
||||
rateLimiter: rateLimiter,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||
request, err := m.checkPATFromRequest(r, auth)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||
// Check if it's a status error, otherwise default to Unauthorized
|
||||
if _, ok := status.FromError(err); !ok {
|
||||
err = status.Errorf(status.Unauthorized, "token invalid")
|
||||
}
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
h.ServeHTTP(w, request)
|
||||
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
|
||||
return r, fmt.Errorf("error extracting token: %w", err)
|
||||
}
|
||||
|
||||
if m.rateLimiter != nil {
|
||||
if !m.rateLimiter.Allow(token) {
|
||||
return r, status.Errorf(status.TooManyRequests, "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
|
||||
if err != nil {
|
||||
|
||||
@@ -27,7 +27,9 @@ const (
|
||||
domainCategory = "domainCategory"
|
||||
userID = "userID"
|
||||
tokenID = "tokenID"
|
||||
tokenID2 = "tokenID2"
|
||||
PAT = "nbp_PAT"
|
||||
PAT2 = "nbp_PAT2"
|
||||
JWT = "JWT"
|
||||
wrongToken = "wrongToken"
|
||||
)
|
||||
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
tokenID2: {
|
||||
ID: tokenID2,
|
||||
Name: "My second token",
|
||||
HashedToken: "someHash2",
|
||||
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
|
||||
CreatedBy: userID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
LastUsed: util.ToPtr(time.Now().UTC()),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
|
||||
if token == PAT {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
if token == PAT2 {
|
||||
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
|
||||
}
|
||||
return nil, nil, "", "", fmt.Errorf("PAT invalid")
|
||||
}
|
||||
|
||||
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
|
||||
}
|
||||
|
||||
func mockMarkPATUsed(_ context.Context, token string) error {
|
||||
if token == tokenID {
|
||||
if token == tokenID || token == tokenID2 {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("Should never get reached")
|
||||
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
|
||||
// Configure rate limiter: 10 requests per minute with burst of 5
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 10,
|
||||
Burst: 5,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make burst requests - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 5; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 5, successCount, "All burst requests should succeed")
|
||||
|
||||
// The 6th request should fail (exceeded burst)
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
|
||||
// Configure very low rate limit: 1 request per minute
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request should fail (rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
})
|
||||
|
||||
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
|
||||
// Configure strict rate limit
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Make multiple requests with Bearer token - all should succeed
|
||||
successCount := 0
|
||||
for i := 0; i < 10; i++ {
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rec, req)
|
||||
if rec.Code == http.StatusOK {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
|
||||
// Configure rate limiter
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 1,
|
||||
Burst: 1,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// Use first PAT token
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
|
||||
|
||||
// Second request with same token should fail
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
|
||||
|
||||
// Use second PAT token - should succeed because it has independent rate limit
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
|
||||
|
||||
// Second request with PAT2 should also be rate limited
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT2)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
|
||||
|
||||
// JWT should still work (not rate limited)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
|
||||
})
|
||||
|
||||
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
|
||||
// Configure rate limiter with short cleanup interval and TTL for testing
|
||||
rateLimitConfig := &RateLimiterConfig{
|
||||
RequestsPerMinute: 60,
|
||||
Burst: 1,
|
||||
CleanupInterval: 100 * time.Millisecond,
|
||||
LimiterTTL: 200 * time.Millisecond,
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
rateLimitConfig,
|
||||
)
|
||||
|
||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
// First request - should succeed
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
|
||||
|
||||
// Second request immediately - should fail (burst exhausted)
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
|
||||
|
||||
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
|
||||
time.Sleep(400 * time.Millisecond)
|
||||
|
||||
// After cleanup, the limiter should be removed and recreated with full burst capacity
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
|
||||
|
||||
// Verify it's a fresh limiter by checking burst is reset
|
||||
req = httptest.NewRequest("GET", "http://testing/test", nil)
|
||||
req.Header.Set("Authorization", "Token "+PAT)
|
||||
rec = httptest.NewRecorder()
|
||||
handler.ServeHTTP(rec, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
||||
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
for _, tc := range tt {
|
||||
|
||||
146
management/server/http/middleware/rate_limiter.go
Normal file
146
management/server/http/middleware/rate_limiter.go
Normal file
@@ -0,0 +1,146 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/time/rate"
|
||||
)
|
||||
|
||||
// RateLimiterConfig holds configuration for the API rate limiter
|
||||
type RateLimiterConfig struct {
|
||||
// RequestsPerMinute defines the rate at which tokens are replenished
|
||||
RequestsPerMinute float64
|
||||
// Burst defines the maximum number of requests that can be made in a burst
|
||||
Burst int
|
||||
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
|
||||
CleanupInterval time.Duration
|
||||
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
|
||||
LimiterTTL time.Duration
|
||||
}
|
||||
|
||||
// DefaultRateLimiterConfig returns a default configuration
|
||||
func DefaultRateLimiterConfig() *RateLimiterConfig {
|
||||
return &RateLimiterConfig{
|
||||
RequestsPerMinute: 100,
|
||||
Burst: 120,
|
||||
CleanupInterval: 5 * time.Minute,
|
||||
LimiterTTL: 10 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// limiterEntry holds a rate limiter and its last access time
|
||||
type limiterEntry struct {
|
||||
limiter *rate.Limiter
|
||||
lastAccess time.Time
|
||||
}
|
||||
|
||||
// APIRateLimiter manages rate limiting for API tokens
|
||||
type APIRateLimiter struct {
|
||||
config *RateLimiterConfig
|
||||
limiters map[string]*limiterEntry
|
||||
mu sync.RWMutex
|
||||
stopChan chan struct{}
|
||||
}
|
||||
|
||||
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
|
||||
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
|
||||
if config == nil {
|
||||
config = DefaultRateLimiterConfig()
|
||||
}
|
||||
|
||||
rl := &APIRateLimiter{
|
||||
config: config,
|
||||
limiters: make(map[string]*limiterEntry),
|
||||
stopChan: make(chan struct{}),
|
||||
}
|
||||
|
||||
go rl.cleanupLoop()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks if a request for the given key (token) is allowed
|
||||
func (rl *APIRateLimiter) Allow(key string) bool {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Allow()
|
||||
}
|
||||
|
||||
// Wait blocks until the rate limiter allows another request for the given key
|
||||
// Returns an error if the context is canceled
|
||||
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
|
||||
limiter := rl.getLimiter(key)
|
||||
return limiter.Wait(ctx)
|
||||
}
|
||||
|
||||
// getLimiter retrieves or creates a rate limiter for the given key
|
||||
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
|
||||
rl.mu.RLock()
|
||||
entry, exists := rl.limiters[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
rl.mu.Lock()
|
||||
entry.lastAccess = time.Now()
|
||||
rl.mu.Unlock()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
if entry, exists := rl.limiters[key]; exists {
|
||||
entry.lastAccess = time.Now()
|
||||
return entry.limiter
|
||||
}
|
||||
|
||||
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
|
||||
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
|
||||
rl.limiters[key] = &limiterEntry{
|
||||
limiter: limiter,
|
||||
lastAccess: time.Now(),
|
||||
}
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// cleanupLoop periodically removes old limiters that haven't been used recently
|
||||
func (rl *APIRateLimiter) cleanupLoop() {
|
||||
ticker := time.NewTicker(rl.config.CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
rl.cleanup()
|
||||
case <-rl.stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanup removes limiters that haven't been used within the TTL period
|
||||
func (rl *APIRateLimiter) cleanup() {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for key, entry := range rl.limiters {
|
||||
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops the cleanup goroutine
|
||||
func (rl *APIRateLimiter) Stop() {
|
||||
close(rl.stopChan)
|
||||
}
|
||||
|
||||
// Reset removes the rate limiter for a specific key
|
||||
func (rl *APIRateLimiter) Reset(key string) {
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
delete(rl.limiters, key)
|
||||
}
|
||||
@@ -125,9 +125,10 @@ type MockAccountManager struct {
|
||||
UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
AllowSyncFunc func(string, uint64) bool
|
||||
UpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
|
||||
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
|
||||
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
|
||||
if am.RecalculateNetworkMapCacheFunc != nil {
|
||||
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
|
||||
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
|
||||
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
137
management/server/networkmap.go
Normal file
137
management/server/networkmap.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
am.enrichAccountFromHolder(account)
|
||||
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getPeerNetworkMapExp(
|
||||
ctx context.Context,
|
||||
accountId string,
|
||||
peerId string,
|
||||
validatedPeers map[string]struct{},
|
||||
customZone nbdns.CustomZone,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
resourcePolicies map[string][]*types.Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
) *types.NetworkMap {
|
||||
account := am.getAccountFromHolderOrInit(accountId)
|
||||
if account == nil {
|
||||
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||
return &types.NetworkMap{
|
||||
Network: &types.Network{},
|
||||
}
|
||||
}
|
||||
|
||||
legacyMap := account.GetPeerNetworkMap(ctx, peerId, customZone, validatedPeers, resourcePolicies, routers, nil)
|
||||
|
||||
go func() {
|
||||
expMap := account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
|
||||
am.compareAndSaveNetworkMaps(ctx, accountId, peerId, expMap, legacyMap)
|
||||
}()
|
||||
|
||||
return legacyMap
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) compareAndSaveNetworkMaps(ctx context.Context, accountId, peerId string, expMap, legacyMap *types.NetworkMap) {
|
||||
expBytes, err := json.Marshal(expMap)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to marshal experimental network map: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
legacyBytes, err := json.Marshal(legacyMap)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to marshal legacy network map: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(expBytes) == len(legacyBytes) {
|
||||
log.WithContext(ctx).Debugf("network maps are equal for peer %s in account %s (size: %d bytes)", peerId, accountId, len(expBytes))
|
||||
return
|
||||
}
|
||||
|
||||
timestamp := time.Now().UnixMicro()
|
||||
baseDir := filepath.Join("debug_networkmaps", accountId, peerId)
|
||||
|
||||
if err := os.MkdirAll(baseDir, 0o755); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to create debug directory %s: %v", baseDir, err)
|
||||
return
|
||||
}
|
||||
|
||||
expFile := filepath.Join(baseDir, fmt.Sprintf("exp_networkmap_%d.json", timestamp))
|
||||
if err := os.WriteFile(expFile, expBytes, 0o644); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to write experimental network map to %s: %v", expFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
legacyFile := filepath.Join(baseDir, fmt.Sprintf("legacy_networkmap_%d.json", timestamp))
|
||||
if err := os.WriteFile(legacyFile, legacyBytes, 0o644); err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to write legacy network map to %s: %v", legacyFile, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("network maps differ for peer %s in account %s - saved to %s (exp: %d bytes, legacy: %d bytes)", peerId, accountId, baseDir, len(expBytes), len(legacyBytes))
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerAddedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||
am.enrichAccountFromHolder(account)
|
||||
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||
account := am.getAccountFromHolder(accountId)
|
||||
if account == nil {
|
||||
return
|
||||
}
|
||||
account.UpdatePeerInNetworkMapCache(peer)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||
account.RecalculateNetworkMapCache(validatedPeers)
|
||||
am.updateAccountInHolder(account)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||
return err
|
||||
}
|
||||
am.recalculateNetworkMapCache(account, validatedPeers)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
|
||||
_, ok := am.expNewNetworkMapAIDs[accountId]
|
||||
return am.expNewNetworkMap || ok
|
||||
}
|
||||
@@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
|
||||
return resource, nil
|
||||
@@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
|
||||
return router, nil
|
||||
@@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
|
||||
@@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
|
||||
|
||||
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
|
||||
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var settings *types.Settings
|
||||
var expired bool
|
||||
@@ -145,6 +140,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
|
||||
}
|
||||
|
||||
if expired {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
|
||||
// the expired one. Here we notify them that connection is now allowed again.
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
@@ -321,6 +319,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
|
||||
}
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
|
||||
if peerLabelChanged || requiresPeerUpdates {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
} else if sshChanged {
|
||||
@@ -381,6 +383,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
|
||||
storeEvent()
|
||||
}
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if userID != activity.SystemInitiator {
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
@@ -417,7 +431,16 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(peer.AccountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil, resourcePolicies, routers)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil)
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -690,6 +713,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
|
||||
|
||||
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
|
||||
@@ -708,11 +742,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
|
||||
|
||||
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
|
||||
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
|
||||
}()
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var peerNotValid bool
|
||||
var isStatusChanged bool
|
||||
@@ -776,6 +805,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -831,6 +863,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
startTransaction := time.Now()
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
@@ -900,8 +933,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
|
||||
|
||||
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
|
||||
}
|
||||
startBuffer := time.Now()
|
||||
am.BufferUpdateAccountPeers(ctx, accountID)
|
||||
log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer))
|
||||
}
|
||||
|
||||
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
|
||||
@@ -997,11 +1037,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
|
||||
}()
|
||||
|
||||
if isRequiresApproval {
|
||||
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
@@ -1014,9 +1049,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return peer, emptyMap, nil, nil
|
||||
}
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||
@@ -1024,10 +1067,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
|
||||
|
||||
@@ -1037,7 +1082,16 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
var networkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
|
||||
} else {
|
||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1167,11 +1221,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
var (
|
||||
account *types.Account
|
||||
err error
|
||||
)
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
account = am.getAccountFromHolderOrInit(accountID)
|
||||
} else {
|
||||
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
globalStart := time.Now()
|
||||
@@ -1204,6 +1265,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||
}
|
||||
|
||||
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||
@@ -1241,7 +1306,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountID) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||
start = time.Now()
|
||||
@@ -1257,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
|
||||
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
@@ -1351,7 +1422,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
return
|
||||
}
|
||||
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
var remotePeerNetworkMap *types.NetworkMap
|
||||
|
||||
if am.experimentalNetworkMap(accountId) {
|
||||
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
|
||||
} else {
|
||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
|
||||
}
|
||||
|
||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||
if ok {
|
||||
@@ -1368,7 +1445,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
|
||||
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
|
||||
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
}
|
||||
|
||||
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
|
||||
@@ -1580,7 +1657,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
|
||||
},
|
||||
},
|
||||
},
|
||||
NetworkMap: &types.NetworkMap{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
peerDeletedEvents = append(peerDeletedEvents, func() {
|
||||
|
||||
@@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap(t *testing.T) {
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testGetNetworkMapGeneral(t)
|
||||
}
|
||||
|
||||
func testGetNetworkMapGeneral(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers_Experimental(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func TestUpdateAccountPeers(t *testing.T) {
|
||||
testUpdateAccountPeers(t)
|
||||
}
|
||||
|
||||
func testUpdateAccountPeers(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
@@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) {
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
|
||||
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
}
|
||||
|
||||
func Test_LoginPeer(t *testing.T) {
|
||||
t.Setenv(envNewNetworkMapBuilder, "true")
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
}
|
||||
|
||||
@@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
|
||||
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
|
||||
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.254.139",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.250.202",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionIN,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
Port: "",
|
||||
PolicyID: "RuleDefault",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "all",
|
||||
@@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
|
||||
// We expect a single permissive firewall rule which all outgoing connections
|
||||
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
|
||||
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
|
||||
assert.Len(t, firewallRules, 1)
|
||||
assert.Len(t, firewallRules, 7)
|
||||
expectedFirewallRules := []*types.FirewallRule{
|
||||
{
|
||||
PeerIP: "0.0.0.0",
|
||||
PeerIP: "100.65.80.39",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.14.88",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.62.5",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.32.206",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.13.186",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.29.55",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
Port: "80",
|
||||
PolicyID: "RuleSwarm",
|
||||
},
|
||||
{
|
||||
PeerIP: "100.65.21.56",
|
||||
Direction: types.FirewallRuleDirectionOUT,
|
||||
Action: "accept",
|
||||
Protocol: "tcp",
|
||||
|
||||
@@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
@@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user