mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 20:26:39 +00:00
Compare commits
23 Commits
v0.47.2
...
test/multi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7ababbf45 | ||
|
|
23b5d45b68 | ||
|
|
e993b633bd | ||
|
|
6aaec1002d | ||
|
|
ebf3d26c91 | ||
|
|
0e5dc9d412 | ||
|
|
91f7ee6a3c | ||
|
|
7c6b85b4cb | ||
|
|
9dc9402deb | ||
|
|
41a9e45c68 | ||
|
|
641891e931 | ||
|
|
c43ddddcdb | ||
|
|
0a9d09267a | ||
|
|
05733b00c1 | ||
|
|
0a5f751343 | ||
|
|
b2a7a4c6d4 | ||
|
|
cfdaa82fea | ||
|
|
c332ff0a47 | ||
|
|
6cd77cc17c | ||
|
|
19835dc6d5 | ||
|
|
3cd21cc7e5 | ||
|
|
4619d39e17 | ||
|
|
5b09804a17 |
@@ -4,12 +4,12 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
// Preferences export a subset of the internal config for gomobile
|
||||
// Preferences exports a subset of the internal config for gomobile
|
||||
type Preferences struct {
|
||||
configInput internal.ConfigInput
|
||||
}
|
||||
|
||||
// NewPreferences create new Preferences instance
|
||||
// NewPreferences creates a new Preferences instance
|
||||
func NewPreferences(configPath string) *Preferences {
|
||||
ci := internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
||||
return &Preferences{ci}
|
||||
}
|
||||
|
||||
// GetManagementURL read url from config file
|
||||
// GetManagementURL reads URL from config file
|
||||
func (p *Preferences) GetManagementURL() (string, error) {
|
||||
if p.configInput.ManagementURL != "" {
|
||||
return p.configInput.ManagementURL, nil
|
||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
||||
return cfg.ManagementURL.String(), err
|
||||
}
|
||||
|
||||
// SetManagementURL store the given url and wait for commit
|
||||
// SetManagementURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetManagementURL(url string) {
|
||||
p.configInput.ManagementURL = url
|
||||
}
|
||||
|
||||
// GetAdminURL read url from config file
|
||||
// GetAdminURL reads URL from config file
|
||||
func (p *Preferences) GetAdminURL() (string, error) {
|
||||
if p.configInput.AdminURL != "" {
|
||||
return p.configInput.AdminURL, nil
|
||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
||||
return cfg.AdminURL.String(), err
|
||||
}
|
||||
|
||||
// SetAdminURL store the given url and wait for commit
|
||||
// SetAdminURL stores the given URL and waits for commit
|
||||
func (p *Preferences) SetAdminURL(url string) {
|
||||
p.configInput.AdminURL = url
|
||||
}
|
||||
|
||||
// GetPreSharedKey read preshared key from config file
|
||||
// GetPreSharedKey reads pre-shared key from config file
|
||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
if p.configInput.PreSharedKey != nil {
|
||||
return *p.configInput.PreSharedKey, nil
|
||||
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||
return cfg.PreSharedKey, err
|
||||
}
|
||||
|
||||
// SetPreSharedKey store the given key and wait for commit
|
||||
// SetPreSharedKey stores the given key and waits for commit
|
||||
func (p *Preferences) SetPreSharedKey(key string) {
|
||||
p.configInput.PreSharedKey = &key
|
||||
}
|
||||
|
||||
// SetRosenpassEnabled store if rosenpass is enabled
|
||||
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||
p.configInput.RosenpassEnabled = &enabled
|
||||
}
|
||||
|
||||
// GetRosenpassEnabled read rosenpass enabled from config file
|
||||
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
if p.configInput.RosenpassEnabled != nil {
|
||||
return *p.configInput.RosenpassEnabled, nil
|
||||
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||
return cfg.RosenpassEnabled, err
|
||||
}
|
||||
|
||||
// SetRosenpassPermissive store the given permissive and wait for commit
|
||||
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||
p.configInput.RosenpassPermissive = &permissive
|
||||
}
|
||||
|
||||
// GetRosenpassPermissive read rosenpass permissive from config file
|
||||
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
if p.configInput.RosenpassPermissive != nil {
|
||||
return *p.configInput.RosenpassPermissive, nil
|
||||
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||
return cfg.RosenpassPermissive, err
|
||||
}
|
||||
|
||||
// Commit write out the changes into config file
|
||||
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||
if p.configInput.DisableClientRoutes != nil {
|
||||
return *p.configInput.DisableClientRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableClientRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableClientRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||
p.configInput.DisableClientRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||
if p.configInput.DisableServerRoutes != nil {
|
||||
return *p.configInput.DisableServerRoutes, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableServerRoutes, err
|
||||
}
|
||||
|
||||
// SetDisableServerRoutes stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||
p.configInput.DisableServerRoutes = &disable
|
||||
}
|
||||
|
||||
// GetDisableDNS reads disable DNS setting from config file
|
||||
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||
if p.configInput.DisableDNS != nil {
|
||||
return *p.configInput.DisableDNS, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableDNS, err
|
||||
}
|
||||
|
||||
// SetDisableDNS stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||
p.configInput.DisableDNS = &disable
|
||||
}
|
||||
|
||||
// GetDisableFirewall reads disable firewall setting from config file
|
||||
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||
if p.configInput.DisableFirewall != nil {
|
||||
return *p.configInput.DisableFirewall, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.DisableFirewall, err
|
||||
}
|
||||
|
||||
// SetDisableFirewall stores the given value and waits for commit
|
||||
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||
p.configInput.DisableFirewall = &disable
|
||||
}
|
||||
|
||||
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||
if p.configInput.ServerSSHAllowed != nil {
|
||||
return *p.configInput.ServerSSHAllowed, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if cfg.ServerSSHAllowed == nil {
|
||||
// Default to false for security on Android
|
||||
return false, nil
|
||||
}
|
||||
return *cfg.ServerSSHAllowed, err
|
||||
}
|
||||
|
||||
// SetServerSSHAllowed stores the given value and waits for commit
|
||||
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||
p.configInput.ServerSSHAllowed = &allowed
|
||||
}
|
||||
|
||||
// GetBlockInbound reads block inbound setting from config file
|
||||
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||
if p.configInput.BlockInbound != nil {
|
||||
return *p.configInput.BlockInbound, nil
|
||||
}
|
||||
|
||||
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return cfg.BlockInbound, err
|
||||
}
|
||||
|
||||
// SetBlockInbound stores the given value and waits for commit
|
||||
func (p *Preferences) SetBlockInbound(block bool) {
|
||||
p.configInput.BlockInbound = &block
|
||||
}
|
||||
|
||||
// Commit writes out the changes to the config file
|
||||
func (p *Preferences) Commit() error {
|
||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||
return err
|
||||
|
||||
@@ -38,5 +38,5 @@ func init() {
|
||||
|
||||
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||
"This overrides any policies received from the management service.")
|
||||
"This overrides any policies received from the management service.")
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ type WGTunDevice struct {
|
||||
mtu int
|
||||
iceBind *bind.ICEBind
|
||||
tunAdapter TunAdapter
|
||||
disableDNS bool
|
||||
|
||||
name string
|
||||
device *device.Device
|
||||
@@ -32,7 +33,7 @@ type WGTunDevice struct {
|
||||
configurer WGConfigurer
|
||||
}
|
||||
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||
return &WGTunDevice{
|
||||
address: address,
|
||||
port: port,
|
||||
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
||||
mtu: mtu,
|
||||
iceBind: iceBind,
|
||||
tunAdapter: tunAdapter,
|
||||
disableDNS: disableDNS,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
||||
routesString := routesToString(routes)
|
||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||
|
||||
// Skip DNS configuration when DisableDNS is enabled
|
||||
if t.disableDNS {
|
||||
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||
dns = ""
|
||||
searchDomainsToString = ""
|
||||
}
|
||||
|
||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||
if err != nil {
|
||||
log.Errorf("failed to create Android interface: %s", err)
|
||||
|
||||
@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
|
||||
MobileArgs *device.MobileIFaceArguments
|
||||
TransportNet transport.Net
|
||||
FilterFn bind.FilterFn
|
||||
DisableDNS bool
|
||||
}
|
||||
|
||||
// WGIface represents an interface instance
|
||||
|
||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
||||
|
||||
wgIFace := &WGIface{
|
||||
userspaceBind: true,
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||
}
|
||||
return wgIFace, nil
|
||||
|
||||
@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
|
||||
//
|
||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||
if drop {
|
||||
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||
|
||||
if hasPortRestrictions {
|
||||
// Don't squash rules with port restrictions
|
||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := protocols[r.Protocol]; !ok {
|
||||
protocols[r.Protocol] = &protoMatch{
|
||||
ips: map[string]int{},
|
||||
|
||||
@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||
}
|
||||
|
||||
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
rules []*mgmProto.FirewallRule
|
||||
expectedCount int
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "should not squash rules with port ranges",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with specific ports",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with legacy port field",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with legacy port field should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should not squash rules with DROP action",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_DROP,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "Rules with DROP action should not be squashed",
|
||||
},
|
||||
{
|
||||
name: "should squash rules without port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 1,
|
||||
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||
},
|
||||
{
|
||||
name: "mixed rules should not squash protocol with port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
PortInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
},
|
||||
},
|
||||
expectedCount: 4,
|
||||
description: "TCP should not be squashed because one rule has port restrictions",
|
||||
},
|
||||
{
|
||||
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||
rules: []*mgmProto.FirewallRule{
|
||||
// TCP rules with port restrictions - should NOT be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_TCP,
|
||||
Port: "443",
|
||||
},
|
||||
// UDP rules without port restrictions - SHOULD be squashed
|
||||
{
|
||||
PeerIP: "10.93.0.1",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.2",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.3",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
{
|
||||
PeerIP: "10.93.0.4",
|
||||
Direction: mgmProto.RuleDirection_IN,
|
||||
Action: mgmProto.RuleAction_ACCEPT,
|
||||
Protocol: mgmProto.RuleProtocol_UDP,
|
||||
},
|
||||
},
|
||||
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||
{AllowedIps: []string{"10.93.0.1"}},
|
||||
{AllowedIps: []string{"10.93.0.2"}},
|
||||
{AllowedIps: []string{"10.93.0.3"}},
|
||||
{AllowedIps: []string{"10.93.0.4"}},
|
||||
},
|
||||
FirewallRules: tt.rules,
|
||||
}
|
||||
|
||||
manager := &DefaultManager{}
|
||||
rules, _ := manager.squashAcceptRules(networkMap)
|
||||
|
||||
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||
|
||||
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||
if tt.expectedCount == 1 {
|
||||
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
portInfo *mgmProto.PortInfo
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil PortInfo should be empty",
|
||||
portInfo: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero port should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 0,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid port should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Port{
|
||||
Port: 80,
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with nil range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: nil,
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero start range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 0,
|
||||
End: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with zero end range should be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 80,
|
||||
End: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "PortInfo with valid range should not be empty",
|
||||
portInfo: &mgmProto.PortInfo{
|
||||
PortSelection: &mgmProto.PortInfo_Range_{
|
||||
Range: &mgmProto.PortInfo_Range{
|
||||
Start: 8080,
|
||||
End: 8090,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := portInfoEmpty(tt.portInfo)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||
networkMap := &mgmProto.NetworkMap{
|
||||
PeerConfig: &mgmProto.PeerConfig{
|
||||
|
||||
@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
// default to disabling server routes on Android for security
|
||||
DisableServerRoutes: runtime.GOOS == "android",
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
@@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||
updated = true
|
||||
} else if config.ServerSSHAllowed == nil {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
if runtime.GOOS == "android" {
|
||||
// default to disabled SSH on Android for security
|
||||
log.Infof("setting SSH server to false by default on Android")
|
||||
config.ServerSSHAllowed = util.False()
|
||||
} else {
|
||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||
config.ServerSSHAllowed = util.True()
|
||||
}
|
||||
updated = true
|
||||
}
|
||||
|
||||
|
||||
@@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
|
||||
MTU: iface.DefaultMTU,
|
||||
TransportNet: transportNet,
|
||||
FilterFn: e.addrViaRoutes,
|
||||
DisableDNS: e.config.DisableDNS,
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
// MockManager is the mock instance of a route manager
|
||||
type MockManager struct {
|
||||
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
|
||||
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
|
||||
TriggerSelectionFunc func(haMap route.HAMap)
|
||||
GetRouteSelectorFunc func() *routeselector.RouteSelector
|
||||
GetClientRoutesFunc func() route.HAMap
|
||||
|
||||
@@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
|
||||
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
|
||||
nets := make([]string, 0)
|
||||
for _, r := range clientRoutes {
|
||||
// filter out domain routes
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
@@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
|
||||
if runtime.GOOS != "android" {
|
||||
return
|
||||
}
|
||||
newNets := make([]string, 0)
|
||||
|
||||
var newNets []string
|
||||
for _, routes := range idMap {
|
||||
for _, r := range routes {
|
||||
if r.IsDynamic() {
|
||||
continue
|
||||
}
|
||||
newNets = append(newNets, r.Network.String())
|
||||
}
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
// OnNewPrefixes is called from iOS only
|
||||
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
newNets := make([]string, 0)
|
||||
for _, prefix := range prefixes {
|
||||
@@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
|
||||
}
|
||||
|
||||
sort.Strings(newNets)
|
||||
switch runtime.GOOS {
|
||||
case "android":
|
||||
if !n.hasDiff(n.initialRouteRanges, newNets) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
if !n.hasDiff(n.routeRanges, newNets) {
|
||||
return
|
||||
}
|
||||
|
||||
n.routeRanges = newNets
|
||||
|
||||
n.notify()
|
||||
}
|
||||
|
||||
|
||||
@@ -59,16 +59,16 @@ type Info struct {
|
||||
Environment Environment
|
||||
Files []File // for posture checks
|
||||
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
RosenpassEnabled bool
|
||||
RosenpassPermissive bool
|
||||
ServerSSHAllowed bool
|
||||
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
DisableClientRoutes bool
|
||||
DisableServerRoutes bool
|
||||
DisableDNS bool
|
||||
DisableFirewall bool
|
||||
BlockLANAccess bool
|
||||
BlockInbound bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
}
|
||||
|
||||
@@ -102,6 +102,8 @@ type DefaultAccountManager struct {
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
updateAccountPeersBufferInterval atomic.Int64
|
||||
|
||||
loginFilter *loginFilter
|
||||
}
|
||||
|
||||
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
|
||||
@@ -195,6 +197,7 @@ func BuildManager(
|
||||
proxyController: proxyController,
|
||||
settingsManager: settingsManager,
|
||||
permissionsManager: permissionsManager,
|
||||
loginFilter: newLoginFilter(),
|
||||
}
|
||||
|
||||
am.startWarmup(ctx)
|
||||
@@ -1536,6 +1539,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
|
||||
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) AllowSync(wgPubKey, metahash string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
@@ -1557,6 +1564,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
|
||||
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
metahash := metaHash(meta, realIP.String())
|
||||
am.loginFilter.addLogin(peerPubKey, metahash)
|
||||
|
||||
return peer, netMap, postureChecks, nil
|
||||
}
|
||||
|
||||
@@ -1570,7 +1580,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
@@ -117,4 +117,5 @@ type Manager interface {
|
||||
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
|
||||
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
|
||||
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
|
||||
AllowSync(string, string) bool
|
||||
}
|
||||
|
||||
@@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
|
||||
for _, groupID := range groupIDs {
|
||||
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
|
||||
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
|
||||
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)
|
||||
|
||||
@@ -166,7 +166,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
|
||||
realIP := getRealIP(ctx)
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String())
|
||||
sip := realIP.String()
|
||||
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s] [%s]", req.WgPubKey, sip, metaHash(extractPeerMeta(ctx, syncReq.GetMeta()), sip))
|
||||
|
||||
if syncReq.GetMeta() == nil {
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
|
||||
113
management/server/loginfilter.go
Normal file
113
management/server/loginfilter.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
loginFilterSize = 100_000 // Size of the login filter map, making it large enough for a future
|
||||
filterTimeout = 5 * time.Minute // Duration to secure the previous login information in the filter
|
||||
|
||||
loggingLimit = 100
|
||||
|
||||
loggingLimitOnePeer = 30
|
||||
loggingTresholdOnePeer = 5 * time.Minute
|
||||
)
|
||||
|
||||
type loginFilter struct {
|
||||
mu sync.RWMutex
|
||||
logged map[string]metahash
|
||||
}
|
||||
|
||||
type metahash struct {
|
||||
hashes map[string]struct{}
|
||||
counter int
|
||||
start time.Time
|
||||
}
|
||||
|
||||
func newLoginFilter() *loginFilter {
|
||||
return &loginFilter{
|
||||
logged: make(map[string]metahash, loginFilterSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *loginFilter) addLogin(wgPubKey, metaHash string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
mh, ok := l.logged[wgPubKey]
|
||||
if !ok {
|
||||
mh = metahash{
|
||||
hashes: make(map[string]struct{}, loggingLimit),
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
mh.hashes[metaHash] = struct{}{}
|
||||
mh.counter++
|
||||
if mh.counter >= loggingLimit && mh.counter%loggingLimit == 0 && len(mh.hashes) > 1 {
|
||||
log.WithFields(log.Fields{
|
||||
"wgPubKey": wgPubKey,
|
||||
"number of different hashes": len(mh.hashes),
|
||||
"elapsed time for number of attempts": time.Since(mh.start),
|
||||
"number of syncs": mh.counter,
|
||||
}).Info(mh.prepareHashes())
|
||||
} else if mh.counter%loggingLimitOnePeer == 0 && time.Since(mh.start) > loggingTresholdOnePeer && len(mh.hashes) == 1 {
|
||||
log.WithFields(log.Fields{
|
||||
"wgPubKey": wgPubKey,
|
||||
"elapsed time for number of attempts": time.Since(mh.start),
|
||||
"number of syncs": mh.counter,
|
||||
}).Info(mh.prepareHashes())
|
||||
mh.start = time.Now()
|
||||
}
|
||||
l.logged[wgPubKey] = mh
|
||||
}
|
||||
|
||||
func (m *metahash) prepareHashes() string {
|
||||
var sb strings.Builder
|
||||
for hash := range m.hashes {
|
||||
sb.WriteString(hash)
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
|
||||
b.WriteString(meta.WtVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.OSVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.KernelVersion)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(mac)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
@@ -119,6 +119,8 @@ type MockAccountManager struct {
|
||||
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
|
||||
|
||||
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
|
||||
|
||||
AllowSyncFunc func(string, string) bool
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
@@ -890,3 +892,7 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) AllowSync(_, _ string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1579,7 +1579,6 @@ func Test_LoginPeer(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
setupKey string
|
||||
wireGuardPubKey string
|
||||
expectExtraDNSLabelsMismatch bool
|
||||
extraDNSLabels []string
|
||||
expectLoginError bool
|
||||
|
||||
@@ -4,19 +4,19 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID)
|
||||
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
|
||||
}
|
||||
|
||||
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
|
||||
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error {
|
||||
func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
|
||||
// routes can have both peer and peer_groups
|
||||
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains)
|
||||
prefix := checkRoute.Network
|
||||
domains := checkRoute.Domains
|
||||
|
||||
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// lets remember all the peers and the peer groups from routesWithPrefix
|
||||
seenPeers := make(map[string]bool)
|
||||
@@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
for _, prefixRoute := range routesWithPrefix {
|
||||
// we skip route(s) with the same network ID as we want to allow updating of the existing route
|
||||
// when creating a new route routeID is newly generated so nothing will be skipped
|
||||
if routeID == prefixRoute.ID {
|
||||
if checkRoute.ID == prefixRoute.ID {
|
||||
continue
|
||||
}
|
||||
|
||||
if prefixRoute.Peer != "" {
|
||||
seenPeers[string(prefixRoute.ID)] = true
|
||||
}
|
||||
|
||||
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, groupID := range prefixRoute.PeerGroups {
|
||||
seenPeerGroups[groupID] = true
|
||||
|
||||
group := account.GetGroup(groupID)
|
||||
if group == nil {
|
||||
group, ok := peerGroupsMap[groupID]
|
||||
if !ok || group == nil {
|
||||
return status.Errorf(
|
||||
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
|
||||
getRouteDescriptor(prefix, domains), groupID,
|
||||
@@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
}
|
||||
|
||||
if peerID != "" {
|
||||
if peerID := checkRoute.Peer; peerID != "" {
|
||||
// check that peerID exists and is not in any route as single peer or part of the group
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
|
||||
if _, ok := seenPeers[peerID]; ok {
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
|
||||
@@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that peerGroupIDs are not in any route peerGroups list
|
||||
for _, groupID := range peerGroupIDs {
|
||||
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again.
|
||||
|
||||
for _, groupID := range checkRoute.PeerGroups {
|
||||
group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
|
||||
if _, ok := seenPeerGroups[groupID]; ok {
|
||||
return status.Errorf(
|
||||
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
|
||||
@@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
|
||||
}
|
||||
|
||||
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
|
||||
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, id := range group.Peers {
|
||||
if _, ok := seenPeers[id]; ok {
|
||||
peer := account.GetPeer(id)
|
||||
if peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
peer, ok := peersMap[id]
|
||||
if !ok || peer == nil {
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
|
||||
}
|
||||
|
||||
return status.Errorf(status.AlreadyExists,
|
||||
"failed to add route with %s - peer %s from the group %s already has this route",
|
||||
getRouteDescriptor(prefix, domains), peer.Name, group.Name)
|
||||
@@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(domains) > 0 && prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
|
||||
if len(domains) == 0 && !prefix.IsValid() {
|
||||
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix")
|
||||
}
|
||||
var newRoute *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
if len(domains) > 0 {
|
||||
prefix = getPlaceholderIP()
|
||||
}
|
||||
|
||||
if peerID != "" && len(peerGroupIDs) != 0 {
|
||||
return nil, status.Errorf(
|
||||
status.InvalidArgument,
|
||||
"peer with ID %s and peers group %s should not be provided at the same time",
|
||||
peerID, peerGroupIDs)
|
||||
}
|
||||
|
||||
var newRoute route.Route
|
||||
newRoute.ID = route.ID(xid.New().String())
|
||||
|
||||
if len(peerGroupIDs) > 0 {
|
||||
err = validateGroups(peerGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
newRoute = &route.Route{
|
||||
ID: route.ID(xid.New().String()),
|
||||
AccountID: accountID,
|
||||
Network: prefix,
|
||||
Domains: domains,
|
||||
KeepRoute: keepRoute,
|
||||
NetID: netID,
|
||||
Description: description,
|
||||
Peer: peerID,
|
||||
PeerGroups: peerGroupIDs,
|
||||
NetworkType: networkType,
|
||||
Masquerade: masquerade,
|
||||
Metric: metric,
|
||||
Enabled: enabled,
|
||||
Groups: groups,
|
||||
AccessControlGroups: accessControlGroupIDs,
|
||||
}
|
||||
}
|
||||
|
||||
if len(accessControlGroupIDs) > 0 {
|
||||
err = validateGroups(accessControlGroupIDs, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains)
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
err = validateGroups(groups, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newRoute.Peer = peerID
|
||||
newRoute.PeerGroups = peerGroupIDs
|
||||
newRoute.Network = prefix
|
||||
newRoute.Domains = domains
|
||||
newRoute.NetworkType = networkType
|
||||
newRoute.Description = description
|
||||
newRoute.NetID = netID
|
||||
newRoute.Masquerade = masquerade
|
||||
newRoute.Metric = metric
|
||||
newRoute.Enabled = enabled
|
||||
newRoute.Groups = groups
|
||||
newRoute.KeepRoute = keepRoute
|
||||
newRoute.AccessControlGroups = accessControlGroupIDs
|
||||
|
||||
if account.Routes == nil {
|
||||
account.Routes = make(map[route.ID]*route.Route)
|
||||
}
|
||||
|
||||
account.Routes[newRoute.ID] = &newRoute
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, &newRoute) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
|
||||
|
||||
return &newRoute, nil
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return newRoute, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves route
|
||||
@@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var oldRoute *route.Route
|
||||
var oldRouteAffectsPeers bool
|
||||
var newRouteAffectsPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
routeToSave.AccountID = accountID
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
if oldRouteAffectsPeers || newRouteAffectsPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var route *route.Route
|
||||
var updateAccountPeers bool
|
||||
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
|
||||
})
|
||||
|
||||
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
|
||||
|
||||
if updateAccountPeers {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
}
|
||||
|
||||
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
|
||||
if routeToSave == nil {
|
||||
return status.Errorf(status.InvalidArgument, "route provided is nil")
|
||||
}
|
||||
@@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
|
||||
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
|
||||
}
|
||||
@@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
|
||||
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
|
||||
}
|
||||
|
||||
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
|
||||
}
|
||||
|
||||
// validateRouteGroups validates the route groups and returns the validated groups map.
|
||||
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
|
||||
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
|
||||
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(routeToSave.PeerGroups) > 0 {
|
||||
err = validateGroups(routeToSave.PeerGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(routeToSave.AccessControlGroups) > 0 {
|
||||
err = validateGroups(routeToSave.AccessControlGroups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains)
|
||||
if err != nil {
|
||||
return err
|
||||
if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = validateGroups(routeToSave.Groups, account.Groups)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oldRoute := account.Routes[routeToSave.ID]
|
||||
account.Routes[routeToSave.ID] = routeToSave
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
routy := account.Routes[routeID]
|
||||
if routy == nil {
|
||||
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
|
||||
}
|
||||
delete(account.Routes, routeID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
|
||||
|
||||
if am.isRouteChangeAffectPeers(account, routy) {
|
||||
am.UpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoutes returns a list of routes from account
|
||||
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
|
||||
if err != nil {
|
||||
return nil, status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
return groupsMap, nil
|
||||
}
|
||||
|
||||
func toProtocolRoute(route *route.Route) *proto.Route {
|
||||
@@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
|
||||
return &portInfo
|
||||
}
|
||||
|
||||
// isRouteChangeAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers
|
||||
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool {
|
||||
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != ""
|
||||
// areRouteChangesAffectPeers checks if a given route affects peers by determining
|
||||
// if it has a routing peer, distribution, or peer groups that include peers.
|
||||
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
|
||||
if route.Peer != "" {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if hasPeers {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
|
||||
}
|
||||
|
||||
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
|
||||
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
|
||||
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := make([]*route.Route, 0)
|
||||
for _, r := range accountRoutes {
|
||||
dynamic := r.IsDynamic()
|
||||
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
|
||||
!dynamic && r.Network.String() == prefix.String() {
|
||||
routes = append(routes, r)
|
||||
}
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
@@ -42,7 +42,10 @@ const (
|
||||
// Type is a type of the Error
|
||||
type Type int32
|
||||
|
||||
var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found")
|
||||
var (
|
||||
ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found")
|
||||
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
|
||||
)
|
||||
|
||||
// Error is an internal error
|
||||
type Error struct {
|
||||
@@ -227,3 +230,7 @@ func NewUserRoleNotFoundError(role string) error {
|
||||
func NewOperationNotFoundError(operation operations.Operation) error {
|
||||
return Errorf(NotFound, "operation: %s not found", operation)
|
||||
}
|
||||
|
||||
func NewRouteNotFoundError(routeID string) error {
|
||||
return Errorf(NotFound, "route: %s not found", routeID)
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -34,6 +32,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
|
||||
|
||||
// GetAccountRoutes retrieves network routes for an account.
|
||||
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
|
||||
return getRecords[*route.Route](s.db, lockStrength, accountID)
|
||||
var routes []*route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Find(&routes, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get routes from store")
|
||||
}
|
||||
|
||||
return routes, nil
|
||||
}
|
||||
|
||||
// GetRouteByID retrieves a route by its ID and account ID.
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) {
|
||||
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID)
|
||||
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
|
||||
var route *route.Route
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&route, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewRouteNotFoundError(routeID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
|
||||
return nil, status.Errorf(status.Internal, "failed to get route from store")
|
||||
}
|
||||
|
||||
return route, nil
|
||||
}
|
||||
|
||||
// SaveRoute saves a route to the database.
|
||||
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to save route to store")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes a route from the database.
|
||||
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
|
||||
if err := result.Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
|
||||
return status.Errorf(status.Internal, "failed to delete route from store")
|
||||
}
|
||||
|
||||
if result.RowsAffected == 0 {
|
||||
return status.NewRouteNotFoundError(routeID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAccountSetupKeys retrieves setup keys for an account.
|
||||
@@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
|
||||
return nil
|
||||
}
|
||||
|
||||
// getRecords retrieves records from the database based on the account ID.
|
||||
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
|
||||
tx := db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record []T
|
||||
|
||||
result := tx.Find(&record, accountIDCondition, accountID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// getRecordByID retrieves a record by its ID and account ID from the database.
|
||||
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
|
||||
tx := db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var record T
|
||||
|
||||
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
|
||||
First(&record, accountAndIDQueryCondition, accountID, recordID)
|
||||
if err := result.Error; err != nil {
|
||||
parts := strings.Split(fmt.Sprintf("%T", record), ".")
|
||||
recordType := parts[len(parts)-1]
|
||||
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
|
||||
}
|
||||
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
|
||||
}
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// SaveDNSSettings saves the DNS settings to the store.
|
||||
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
|
||||
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
|
||||
|
||||
@@ -19,21 +19,17 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
route2 "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
|
||||
@@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 8003, len(accountGroups))
|
||||
}
|
||||
func TestSqlStore_GetAccountRoutes(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID string
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "retrieve routes by existing account ID",
|
||||
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existing account ID",
|
||||
accountID: "nonexistent",
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "empty account ID",
|
||||
accountID: "",
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, routes, tt.expectedCount)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_GetRouteByID(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
tests := []struct {
|
||||
name string
|
||||
routeID string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "retrieve existing route",
|
||||
routeID: "ct03t427qv97vmtmglog",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "retrieve non-existing route",
|
||||
routeID: "non-existing",
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "retrieve with empty route ID",
|
||||
routeID: "",
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
|
||||
if tt.expectError {
|
||||
require.Error(t, err)
|
||||
sErr, ok := status.FromError(err)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, sErr.Type(), status.NotFound)
|
||||
require.Nil(t, route)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, route)
|
||||
require.Equal(t, tt.routeID, string(route.ID))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSqlStore_SaveRoute(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
|
||||
route := &route2.Route{
|
||||
ID: "route-id",
|
||||
AccountID: accountID,
|
||||
Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||
NetID: "netID",
|
||||
PeerGroups: []string{"routeA"},
|
||||
NetworkType: route2.IPv4Network,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{"groupA"},
|
||||
AccessControlGroups: []string{},
|
||||
}
|
||||
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
|
||||
require.NoError(t, err)
|
||||
|
||||
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, route, saveRoute)
|
||||
|
||||
}
|
||||
|
||||
func TestSqlStore_DeleteRoute(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
require.NoError(t, err)
|
||||
|
||||
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
|
||||
routeID := "ct03t427qv97vmtmglog"
|
||||
|
||||
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
|
||||
require.NoError(t, err)
|
||||
|
||||
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, route)
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountMeta(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
|
||||
|
||||
@@ -145,7 +145,9 @@ type Store interface {
|
||||
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
|
||||
|
||||
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error)
|
||||
GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
|
||||
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
|
||||
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
|
||||
|
||||
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
|
||||
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)
|
||||
|
||||
@@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
|
||||
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
|
||||
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
|
||||
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
|
||||
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
|
||||
INSERT INTO installations VALUES(1,'');
|
||||
|
||||
Reference in New Issue
Block a user