Compare commits

...

31 Commits

Author SHA1 Message Date
crn4
c7ababbf45 Merge branch 'main' into test/multiple-peer-logging 2025-06-18 19:04:47 +02:00
Viktor Liu
23b5d45b68 [client] Fix port range squashing (#4007) 2025-06-18 18:56:48 +02:00
crn4
e993b633bd Merge branch 'main' into test/multiple-peer-logging 2025-06-18 18:35:11 +02:00
crn4
6aaec1002d minor fix 2025-06-18 18:22:58 +02:00
crn4
ebf3d26c91 another log and different approach for meta calc 2025-06-18 18:16:45 +02:00
Viktor Liu
0e5dc9d412 [client] Add more Android advanced settings (#4001) 2025-06-18 17:23:23 +02:00
Zoltan Papp
91f7ee6a3c Fix route notification
On Android ignore the dynamic roots in the route notifications
2025-06-18 16:49:03 +02:00
Bethuel Mmbaga
7c6b85b4cb [management] Refactor routes to use store methods (#2928) 2025-06-18 16:40:29 +03:00
crn4
9dc9402deb Merge branch 'main' into test/multiple-peer-logging 2025-06-17 19:21:05 +02:00
crn4
41a9e45c68 add metahash to sync log 2025-06-17 18:56:59 +02:00
crn4
641891e931 minor fix - logs for incrementing number of attempts 2025-06-17 18:05:52 +02:00
crn4
c43ddddcdb added delimeter to hash 2025-06-17 17:56:06 +02:00
crn4
0a9d09267a minor fix - removed errors 2025-06-17 17:53:44 +02:00
crn4
05733b00c1 test logging to asses number of peers with the same pub key 2025-06-17 17:50:09 +02:00
hakansa
08c9107c61 [client] fix connection state handling (#3995)
[client] fix connection state handling (#3995)
2025-06-17 17:14:08 +03:00
hakansa
81d83245e1 [client] Fix logic in updateStatus to correctly handle connection state (#3994)
[client] Fix logic in updateStatus to correctly handle connection state (#3994)
2025-06-17 17:02:04 +03:00
Maycon Santos
af2b427751 [management] Avoid recalculating next peer expiration (#3991)
* Avoid recalculating next peer expiration

- Check if an account schedule is already running
- Cancel executing schedules only when changes occurs
- Add more context info to logs

* fix tests
2025-06-17 15:14:11 +02:00
hakansa
f61ebdb3bc [client] Fix DNS Interceptor Build Error (#3993)
[client] Fix DNS Interceptor Build Error
2025-06-17 16:07:14 +03:00
Viktor Liu
de7384e8ea [client] Tighten allowed domains for dns forwarder (#3978) 2025-06-17 14:03:00 +02:00
Viktor Liu
75c1be69cf [client] Prioritze the local resolver in the dns handler chain (#3965) 2025-06-17 14:02:30 +02:00
hakansa
424ae28de9 [client] Fix UI Download URL (#3990)
[client] Fix UI Download URL
2025-06-17 11:55:48 +03:00
Viktor Liu
d4a800edd5 [client] Fix status recorder panic (#3988) 2025-06-17 01:20:26 +02:00
crn4
0a5f751343 removed unused test 2025-06-16 18:10:59 +02:00
crn4
b2a7a4c6d4 unrecoverable error for login 2025-06-16 17:48:13 +02:00
crn4
cfdaa82fea login filter check on login 2025-06-16 15:21:00 +02:00
crn4
c332ff0a47 fixed error wrap 2025-06-16 13:31:01 +02:00
crn4
6cd77cc17c minor fix 2025-06-16 13:16:52 +02:00
crn4
19835dc6d5 change grpc code to permission denied 2025-06-16 13:10:38 +02:00
crn4
3cd21cc7e5 Merge branch 'main' into fix/login-filter 2025-06-16 09:41:48 +02:00
crn4
4619d39e17 updated approach to filtering - through sync 2025-06-16 09:40:25 +02:00
crn4
5b09804a17 added login filter to filter different peers with the same pub key 2025-06-12 15:56:32 +01:00
39 changed files with 2092 additions and 470 deletions

View File

@@ -4,12 +4,12 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
) )
// Preferences export a subset of the internal config for gomobile // Preferences exports a subset of the internal config for gomobile
type Preferences struct { type Preferences struct {
configInput internal.ConfigInput configInput internal.ConfigInput
} }
// NewPreferences create new Preferences instance // NewPreferences creates a new Preferences instance
func NewPreferences(configPath string) *Preferences { func NewPreferences(configPath string) *Preferences {
ci := internal.ConfigInput{ ci := internal.ConfigInput{
ConfigPath: configPath, ConfigPath: configPath,
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
return &Preferences{ci} return &Preferences{ci}
} }
// GetManagementURL read url from config file // GetManagementURL reads URL from config file
func (p *Preferences) GetManagementURL() (string, error) { func (p *Preferences) GetManagementURL() (string, error) {
if p.configInput.ManagementURL != "" { if p.configInput.ManagementURL != "" {
return p.configInput.ManagementURL, nil return p.configInput.ManagementURL, nil
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
return cfg.ManagementURL.String(), err return cfg.ManagementURL.String(), err
} }
// SetManagementURL store the given url and wait for commit // SetManagementURL stores the given URL and waits for commit
func (p *Preferences) SetManagementURL(url string) { func (p *Preferences) SetManagementURL(url string) {
p.configInput.ManagementURL = url p.configInput.ManagementURL = url
} }
// GetAdminURL read url from config file // GetAdminURL reads URL from config file
func (p *Preferences) GetAdminURL() (string, error) { func (p *Preferences) GetAdminURL() (string, error) {
if p.configInput.AdminURL != "" { if p.configInput.AdminURL != "" {
return p.configInput.AdminURL, nil return p.configInput.AdminURL, nil
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
return cfg.AdminURL.String(), err return cfg.AdminURL.String(), err
} }
// SetAdminURL store the given url and wait for commit // SetAdminURL stores the given URL and waits for commit
func (p *Preferences) SetAdminURL(url string) { func (p *Preferences) SetAdminURL(url string) {
p.configInput.AdminURL = url p.configInput.AdminURL = url
} }
// GetPreSharedKey read preshared key from config file // GetPreSharedKey reads pre-shared key from config file
func (p *Preferences) GetPreSharedKey() (string, error) { func (p *Preferences) GetPreSharedKey() (string, error) {
if p.configInput.PreSharedKey != nil { if p.configInput.PreSharedKey != nil {
return *p.configInput.PreSharedKey, nil return *p.configInput.PreSharedKey, nil
@@ -66,17 +66,17 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
return cfg.PreSharedKey, err return cfg.PreSharedKey, err
} }
// SetPreSharedKey store the given key and wait for commit // SetPreSharedKey stores the given key and waits for commit
func (p *Preferences) SetPreSharedKey(key string) { func (p *Preferences) SetPreSharedKey(key string) {
p.configInput.PreSharedKey = &key p.configInput.PreSharedKey = &key
} }
// SetRosenpassEnabled store if rosenpass is enabled // SetRosenpassEnabled stores whether Rosenpass is enabled
func (p *Preferences) SetRosenpassEnabled(enabled bool) { func (p *Preferences) SetRosenpassEnabled(enabled bool) {
p.configInput.RosenpassEnabled = &enabled p.configInput.RosenpassEnabled = &enabled
} }
// GetRosenpassEnabled read rosenpass enabled from config file // GetRosenpassEnabled reads Rosenpass enabled status from config file
func (p *Preferences) GetRosenpassEnabled() (bool, error) { func (p *Preferences) GetRosenpassEnabled() (bool, error) {
if p.configInput.RosenpassEnabled != nil { if p.configInput.RosenpassEnabled != nil {
return *p.configInput.RosenpassEnabled, nil return *p.configInput.RosenpassEnabled, nil
@@ -89,12 +89,12 @@ func (p *Preferences) GetRosenpassEnabled() (bool, error) {
return cfg.RosenpassEnabled, err return cfg.RosenpassEnabled, err
} }
// SetRosenpassPermissive store the given permissive and wait for commit // SetRosenpassPermissive stores the given permissive setting and waits for commit
func (p *Preferences) SetRosenpassPermissive(permissive bool) { func (p *Preferences) SetRosenpassPermissive(permissive bool) {
p.configInput.RosenpassPermissive = &permissive p.configInput.RosenpassPermissive = &permissive
} }
// GetRosenpassPermissive read rosenpass permissive from config file // GetRosenpassPermissive reads Rosenpass permissive setting from config file
func (p *Preferences) GetRosenpassPermissive() (bool, error) { func (p *Preferences) GetRosenpassPermissive() (bool, error) {
if p.configInput.RosenpassPermissive != nil { if p.configInput.RosenpassPermissive != nil {
return *p.configInput.RosenpassPermissive, nil return *p.configInput.RosenpassPermissive, nil
@@ -107,7 +107,119 @@ func (p *Preferences) GetRosenpassPermissive() (bool, error) {
return cfg.RosenpassPermissive, err return cfg.RosenpassPermissive, err
} }
// Commit write out the changes into config file // GetDisableClientRoutes reads disable client routes setting from config file
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
if p.configInput.DisableClientRoutes != nil {
return *p.configInput.DisableClientRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableClientRoutes, err
}
// SetDisableClientRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableClientRoutes(disable bool) {
p.configInput.DisableClientRoutes = &disable
}
// GetDisableServerRoutes reads disable server routes setting from config file
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
if p.configInput.DisableServerRoutes != nil {
return *p.configInput.DisableServerRoutes, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableServerRoutes, err
}
// SetDisableServerRoutes stores the given value and waits for commit
func (p *Preferences) SetDisableServerRoutes(disable bool) {
p.configInput.DisableServerRoutes = &disable
}
// GetDisableDNS reads disable DNS setting from config file
func (p *Preferences) GetDisableDNS() (bool, error) {
if p.configInput.DisableDNS != nil {
return *p.configInput.DisableDNS, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableDNS, err
}
// SetDisableDNS stores the given value and waits for commit
func (p *Preferences) SetDisableDNS(disable bool) {
p.configInput.DisableDNS = &disable
}
// GetDisableFirewall reads disable firewall setting from config file
func (p *Preferences) GetDisableFirewall() (bool, error) {
if p.configInput.DisableFirewall != nil {
return *p.configInput.DisableFirewall, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.DisableFirewall, err
}
// SetDisableFirewall stores the given value and waits for commit
func (p *Preferences) SetDisableFirewall(disable bool) {
p.configInput.DisableFirewall = &disable
}
// GetServerSSHAllowed reads server SSH allowed setting from config file
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
if p.configInput.ServerSSHAllowed != nil {
return *p.configInput.ServerSSHAllowed, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
if cfg.ServerSSHAllowed == nil {
// Default to false for security on Android
return false, nil
}
return *cfg.ServerSSHAllowed, err
}
// SetServerSSHAllowed stores the given value and waits for commit
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
p.configInput.ServerSSHAllowed = &allowed
}
// GetBlockInbound reads block inbound setting from config file
func (p *Preferences) GetBlockInbound() (bool, error) {
if p.configInput.BlockInbound != nil {
return *p.configInput.BlockInbound, nil
}
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
if err != nil {
return false, err
}
return cfg.BlockInbound, err
}
// SetBlockInbound stores the given value and waits for commit
func (p *Preferences) SetBlockInbound(block bool) {
p.configInput.BlockInbound = &block
}
// Commit writes out the changes to the config file
func (p *Preferences) Commit() error { func (p *Preferences) Commit() error {
_, err := internal.UpdateOrCreateConfig(p.configInput) _, err := internal.UpdateOrCreateConfig(p.configInput)
return err return err

View File

@@ -38,5 +38,5 @@ func init() {
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false, upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+ "Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
"This overrides any policies received from the management service.") "This overrides any policies received from the management service.")
} }

View File

@@ -24,6 +24,7 @@ type WGTunDevice struct {
mtu int mtu int
iceBind *bind.ICEBind iceBind *bind.ICEBind
tunAdapter TunAdapter tunAdapter TunAdapter
disableDNS bool
name string name string
device *device.Device device *device.Device
@@ -32,7 +33,7 @@ type WGTunDevice struct {
configurer WGConfigurer configurer WGConfigurer
} }
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice { func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
return &WGTunDevice{ return &WGTunDevice{
address: address, address: address,
port: port, port: port,
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
mtu: mtu, mtu: mtu,
iceBind: iceBind, iceBind: iceBind,
tunAdapter: tunAdapter, tunAdapter: tunAdapter,
disableDNS: disableDNS,
} }
} }
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
routesString := routesToString(routes) routesString := routesToString(routes)
searchDomainsToString := searchDomainsToString(searchDomains) searchDomainsToString := searchDomainsToString(searchDomains)
// Skip DNS configuration when DisableDNS is enabled
if t.disableDNS {
log.Info("DNS is disabled, skipping DNS and search domain configuration")
dns = ""
searchDomainsToString = ""
}
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString) fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
if err != nil { if err != nil {
log.Errorf("failed to create Android interface: %s", err) log.Errorf("failed to create Android interface: %s", err)

View File

@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
MobileArgs *device.MobileIFaceArguments MobileArgs *device.MobileIFaceArguments
TransportNet transport.Net TransportNet transport.Net
FilterFn bind.FilterFn FilterFn bind.FilterFn
DisableDNS bool
} }
// WGIface represents an interface instance // WGIface represents an interface instance

View File

@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
wgIFace := &WGIface{ wgIFace := &WGIface{
userspaceBind: true, userspaceBind: true,
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter), tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
wgProxyFactory: wgproxy.NewUSPFactory(iceBind), wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
} }
return wgIFace, nil return wgIFace, nil

View File

@@ -398,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
// //
// We zeroed this to notify squash function that this protocol can't be squashed. // We zeroed this to notify squash function that this protocol can't be squashed.
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) { addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != "" hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
if drop { r.Port != "" || !portInfoEmpty(r.PortInfo)
if hasPortRestrictions {
// Don't squash rules with port restrictions
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}} protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
return return
} }
if _, ok := protocols[r.Protocol]; !ok { if _, ok := protocols[r.Protocol]; !ok {
protocols[r.Protocol] = &protoMatch{ protocols[r.Protocol] = &protoMatch{
ips: map[string]int{}, ips: map[string]int{},

View File

@@ -330,6 +330,434 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
assert.Equal(t, len(networkMap.FirewallRules), len(rules)) assert.Equal(t, len(networkMap.FirewallRules), len(rules))
} }
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
tests := []struct {
name string
rules []*mgmProto.FirewallRule
expectedCount int
description string
}{
{
name: "should not squash rules with port ranges",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
},
},
expectedCount: 4,
description: "Rules with port ranges should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with specific ports",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
},
expectedCount: 4,
description: "Rules with specific ports should not be squashed even if they cover all peers",
},
{
name: "should not squash rules with legacy port field",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
},
expectedCount: 4,
description: "Rules with legacy port field should not be squashed",
},
{
name: "should not squash rules with DROP action",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_DROP,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "Rules with DROP action should not be squashed",
},
{
name: "should squash rules without port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 1,
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
},
{
name: "mixed rules should not squash protocol with port restrictions",
rules: []*mgmProto.FirewallRule{
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
PortInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
},
},
expectedCount: 4,
description: "TCP should not be squashed because one rule has port restrictions",
},
{
name: "should squash UDP but not TCP when TCP has port restrictions",
rules: []*mgmProto.FirewallRule{
// TCP rules with port restrictions - should NOT be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_TCP,
Port: "443",
},
// UDP rules without port restrictions - SHOULD be squashed
{
PeerIP: "10.93.0.1",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.2",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.3",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
{
PeerIP: "10.93.0.4",
Direction: mgmProto.RuleDirection_IN,
Action: mgmProto.RuleAction_ACCEPT,
Protocol: mgmProto.RuleProtocol_UDP,
},
},
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
networkMap := &mgmProto.NetworkMap{
RemotePeers: []*mgmProto.RemotePeerConfig{
{AllowedIps: []string{"10.93.0.1"}},
{AllowedIps: []string{"10.93.0.2"}},
{AllowedIps: []string{"10.93.0.3"}},
{AllowedIps: []string{"10.93.0.4"}},
},
FirewallRules: tt.rules,
}
manager := &DefaultManager{}
rules, _ := manager.squashAcceptRules(networkMap)
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
// For squashed rules, verify we get the expected 0.0.0.0 rule
if tt.expectedCount == 1 {
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
}
})
}
}
func TestPortInfoEmpty(t *testing.T) {
tests := []struct {
name string
portInfo *mgmProto.PortInfo
expected bool
}{
{
name: "nil PortInfo should be empty",
portInfo: nil,
expected: true,
},
{
name: "PortInfo with zero port should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 0,
},
},
expected: true,
},
{
name: "PortInfo with valid port should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Port{
Port: 80,
},
},
expected: false,
},
{
name: "PortInfo with nil range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: nil,
},
},
expected: true,
},
{
name: "PortInfo with zero start range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 0,
End: 100,
},
},
},
expected: true,
},
{
name: "PortInfo with zero end range should be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 80,
End: 0,
},
},
},
expected: true,
},
{
name: "PortInfo with valid range should not be empty",
portInfo: &mgmProto.PortInfo{
PortSelection: &mgmProto.PortInfo_Range_{
Range: &mgmProto.PortInfo_Range{
Start: 8080,
End: 8090,
},
},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := portInfoEmpty(tt.portInfo)
assert.Equal(t, tt.expected, result)
})
}
}
func TestDefaultManagerEnableSSHRules(t *testing.T) { func TestDefaultManagerEnableSSHRules(t *testing.T) {
networkMap := &mgmProto.NetworkMap{ networkMap := &mgmProto.NetworkMap{
PeerConfig: &mgmProto.PeerConfig{ PeerConfig: &mgmProto.PeerConfig{

View File

@@ -223,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
config := &Config{ config := &Config{
// defaults to false only for new (post 0.26) configurations // defaults to false only for new (post 0.26) configurations
ServerSSHAllowed: util.False(), ServerSSHAllowed: util.False(),
// default to disabling server routes on Android for security
DisableServerRoutes: runtime.GOOS == "android",
} }
if _, err := config.apply(input); err != nil { if _, err := config.apply(input); err != nil {
@@ -416,9 +418,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
config.ServerSSHAllowed = input.ServerSSHAllowed config.ServerSSHAllowed = input.ServerSSHAllowed
updated = true updated = true
} else if config.ServerSSHAllowed == nil { } else if config.ServerSSHAllowed == nil {
// enables SSH for configs from old versions to preserve backwards compatibility if runtime.GOOS == "android" {
log.Infof("falling back to enabled SSH server for pre-existing configuration") // default to disabled SSH on Android for security
config.ServerSSHAllowed = util.True() log.Infof("setting SSH server to false by default on Android")
config.ServerSSHAllowed = util.False()
} else {
// enables SSH for configs from old versions to preserve backwards compatibility
log.Infof("falling back to enabled SSH server for pre-existing configuration")
config.ServerSSHAllowed = util.True()
}
updated = true updated = true
} }

View File

@@ -11,9 +11,10 @@ import (
) )
const ( const (
PriorityDNSRoute = 100 PriorityLocal = 100
PriorityMatchDomain = 50 PriorityDNSRoute = 75
PriorityDefault = 1 PriorityUpstream = 50
PriorityDefault = 1
) )
type SubdomainMatcher interface { type SubdomainMatcher interface {

View File

@@ -22,7 +22,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
// Setup handlers with different priorities // Setup handlers with different priorities
chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault) chain.AddHandler("example.com.", defaultHandler, nbdns.PriorityDefault)
chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", matchDomainHandler, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, nbdns.PriorityDNSRoute)
// Create test request // Create test request
@@ -200,7 +200,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "*.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "*.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "test.example.com.", queryDomain: "test.example.com.",
@@ -214,7 +214,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
priority int priority int
}{ }{
{pattern: "*.example.com.", priority: nbdns.PriorityDefault}, {pattern: "*.example.com.", priority: nbdns.PriorityDefault},
{pattern: "test.example.com.", priority: nbdns.PriorityMatchDomain}, {pattern: "test.example.com.", priority: nbdns.PriorityUpstream},
{pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute}, {pattern: "*.test.example.com.", priority: nbdns.PriorityDNSRoute},
}, },
queryDomain: "sub.test.example.com.", queryDomain: "sub.test.example.com.",
@@ -281,7 +281,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
// Add handlers in priority order // Add handlers in priority order
chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute) chain.AddHandler("example.com.", handler1, nbdns.PriorityDNSRoute)
chain.AddHandler("example.com.", handler2, nbdns.PriorityMatchDomain) chain.AddHandler("example.com.", handler2, nbdns.PriorityUpstream)
chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault) chain.AddHandler("example.com.", handler3, nbdns.PriorityDefault)
// Create test request // Create test request
@@ -344,13 +344,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: true, nbdns.PriorityUpstream: true,
}, },
}, },
{ {
@@ -361,13 +361,13 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: true, nbdns.PriorityDNSRoute: true,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
}, },
}, },
{ {
@@ -378,16 +378,16 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
priority int priority int
}{ }{
{"add", "example.com.", nbdns.PriorityDNSRoute}, {"add", "example.com.", nbdns.PriorityDNSRoute},
{"add", "example.com.", nbdns.PriorityMatchDomain}, {"add", "example.com.", nbdns.PriorityUpstream},
{"add", "example.com.", nbdns.PriorityDefault}, {"add", "example.com.", nbdns.PriorityDefault},
{"remove", "example.com.", nbdns.PriorityDNSRoute}, {"remove", "example.com.", nbdns.PriorityDNSRoute},
{"remove", "example.com.", nbdns.PriorityMatchDomain}, {"remove", "example.com.", nbdns.PriorityUpstream},
}, },
query: "example.com.", query: "example.com.",
expectedCalls: map[int]bool{ expectedCalls: map[int]bool{
nbdns.PriorityDNSRoute: false, nbdns.PriorityDNSRoute: false,
nbdns.PriorityMatchDomain: false, nbdns.PriorityUpstream: false,
nbdns.PriorityDefault: true, nbdns.PriorityDefault: true,
}, },
}, },
} }
@@ -454,7 +454,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
// Add handlers in mixed order // Add handlers in mixed order
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault) chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute) chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain) chain.AddHandler(testDomain, matchHandler, nbdns.PriorityUpstream)
// Test 1: Initial state // Test 1: Initial state
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
@@ -490,7 +490,7 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
defaultHandler.Calls = nil defaultHandler.Calls = nil
// Test 3: Remove middle priority handler // Test 3: Remove middle priority handler
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain) chain.RemoveHandler(testDomain, nbdns.PriorityUpstream)
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}} w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
// Now lowest priority handler (defaultHandler) should be called // Now lowest priority handler (defaultHandler) should be called
@@ -607,7 +607,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
shouldMatch bool shouldMatch bool
}{ }{
{"EXAMPLE.COM.", nbdns.PriorityDefault, false, false}, {"EXAMPLE.COM.", nbdns.PriorityDefault, false, false},
{"example.com.", nbdns.PriorityMatchDomain, false, false}, {"example.com.", nbdns.PriorityUpstream, false, false},
{"Example.Com.", nbdns.PriorityDNSRoute, false, true}, {"Example.Com.", nbdns.PriorityDNSRoute, false, true},
}, },
query: "example.com.", query: "example.com.",
@@ -702,8 +702,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -717,8 +717,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -732,10 +732,10 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "sub.example.com.", nbdns.PriorityUpstream, true},
{"add", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "test.sub.example.com.", nbdns.PriorityUpstream, false},
{"remove", "test.sub.example.com.", nbdns.PriorityMatchDomain, false}, {"remove", "test.sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "test.sub.example.com.", query: "test.sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",
@@ -749,7 +749,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
{"add", "example.com.", nbdns.PriorityDNSRoute, true}, {"add", "example.com.", nbdns.PriorityDNSRoute, true},
}, },
query: "sub.example.com.", query: "sub.example.com.",
@@ -764,9 +764,9 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
priority int priority int
subdomain bool subdomain bool
}{ }{
{"add", "example.com.", nbdns.PriorityMatchDomain, true}, {"add", "example.com.", nbdns.PriorityUpstream, true},
{"add", "other.example.com.", nbdns.PriorityMatchDomain, true}, {"add", "other.example.com.", nbdns.PriorityUpstream, true},
{"add", "sub.example.com.", nbdns.PriorityMatchDomain, false}, {"add", "sub.example.com.", nbdns.PriorityUpstream, false},
}, },
query: "sub.example.com.", query: "sub.example.com.",
expectedMatch: "sub.example.com.", expectedMatch: "sub.example.com.",

View File

@@ -527,7 +527,7 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
muxUpdates = append(muxUpdates, handlerWrapper{ muxUpdates = append(muxUpdates, handlerWrapper{
domain: customZone.Domain, domain: customZone.Domain,
handler: s.localResolver, handler: s.localResolver,
priority: PriorityMatchDomain, priority: PriorityLocal,
}) })
for _, record := range customZone.Records { for _, record := range customZone.Records {
@@ -566,7 +566,7 @@ func (s *DefaultServer) buildUpstreamHandlerUpdate(nameServerGroups []*nbdns.Nam
groupedNS := groupNSGroupsByDomain(nameServerGroups) groupedNS := groupNSGroupsByDomain(nameServerGroups)
for _, domainGroup := range groupedNS { for _, domainGroup := range groupedNS {
basePriority := PriorityMatchDomain basePriority := PriorityUpstream
if domainGroup.domain == nbdns.RootZone { if domainGroup.domain == nbdns.RootZone {
basePriority = PriorityDefault basePriority = PriorityDefault
} }
@@ -588,10 +588,14 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
// Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts // Decrement priority by handler index (0, 1, 2, ...) to avoid conflicts
priority := basePriority - i priority := basePriority - i
// Check if we're about to overlap with the next priority tier // Check if we're about to overlap with the next priority tier.
if basePriority == PriorityMatchDomain && priority <= PriorityDefault { // This boundary check ensures that the priority of upstream handlers does not conflict
// with the default priority tier. By decrementing the priority for each handler, we avoid
// overlaps, but if the calculated priority falls into the default tier, we skip the remaining
// handlers to maintain the integrity of the priority system.
if basePriority == PriorityUpstream && priority <= PriorityDefault {
log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers", log.Warnf("too many handlers for domain=%s, would overlap with default priority tier (diff=%d). Skipping remaining handlers",
domainGroup.domain, PriorityMatchDomain-PriorityDefault) domainGroup.domain, PriorityUpstream-PriorityDefault)
break break
} }

View File

@@ -164,12 +164,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
dummyHandler.ID(): handlerWrapper{ dummyHandler.ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
generateDummyHandler(".", nameServers).ID(): handlerWrapper{ generateDummyHandler(".", nameServers).ID(): handlerWrapper{
domain: nbdns.RootZone, domain: nbdns.RootZone,
@@ -186,7 +186,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -210,12 +210,12 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{ generateDummyHandler("netbird.io", nameServers).ID(): handlerWrapper{
domain: "netbird.io", domain: "netbird.io",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"local-resolver": handlerWrapper{ "local-resolver": handlerWrapper{
domain: "netbird.cloud", domain: "netbird.cloud",
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityLocal,
}, },
}, },
expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}}, expectedLocalQs: []dns.Question{{Name: zoneRecords[0].Name, Qtype: 1, Qclass: 1}},
@@ -305,7 +305,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -321,7 +321,7 @@ func TestUpdateDNSServer(t *testing.T) {
generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{ generateDummyHandler(zoneRecords[0].Name, nameServers).ID(): handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: dummyHandler, handler: dummyHandler,
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
initSerial: 0, initSerial: 0,
@@ -495,7 +495,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
"id1": handlerWrapper{ "id1": handlerWrapper{
domain: zoneRecords[0].Name, domain: zoneRecords[0].Name,
handler: &local.Resolver{}, handler: &local.Resolver{},
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
//dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}} //dnsServer.localResolver.RegisteredMap = local.RegistrationMap{local.BuildRecordKey("netbird.cloud", dns.ClassINET, dns.TypeA): struct{}{}}
@@ -978,7 +978,7 @@ func TestHandlerChain_DomainPriorities(t *testing.T) {
} }
chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute) chain.AddHandler("example.com.", dnsRouteHandler, PriorityDNSRoute)
chain.AddHandler("example.com.", upstreamHandler, PriorityMatchDomain) chain.AddHandler("example.com.", upstreamHandler, PriorityUpstream)
testCases := []struct { testCases := []struct {
name string name string
@@ -1059,14 +1059,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
} }
@@ -1093,21 +1093,21 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
"upstream-group2": { "upstream-group2": {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
"upstream-other": { "upstream-other": {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
} }
@@ -1128,7 +1128,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1146,7 +1146,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1164,7 +1164,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain + 1, priority: PriorityUpstream + 1,
}, },
// Keep existing groups with their original priorities // Keep existing groups with their original priorities
{ {
@@ -1172,14 +1172,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1199,14 +1199,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
// Add group3 with lowest priority // Add group3 with lowest priority
{ {
@@ -1214,7 +1214,7 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group3", Id: "upstream-group3",
}, },
priority: PriorityMatchDomain - 2, priority: PriorityUpstream - 2,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1335,14 +1335,14 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1360,28 +1360,28 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group1", Id: "upstream-group1",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "example.com", domain: "example.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-group2", Id: "upstream-group2",
}, },
priority: PriorityMatchDomain - 1, priority: PriorityUpstream - 1,
}, },
{ {
domain: "other.com", domain: "other.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-other", Id: "upstream-other",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
{ {
domain: "new.com", domain: "new.com",
handler: &mockHandler{ handler: &mockHandler{
Id: "upstream-new", Id: "upstream-new",
}, },
priority: PriorityMatchDomain, priority: PriorityUpstream,
}, },
}, },
expectedHandlers: map[string]string{ expectedHandlers: map[string]string{
@@ -1791,14 +1791,14 @@ func TestExtraDomainsRefCounting(t *testing.T) {
// Register domains from different handlers with same domain // Register domains from different handlers with same domain
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute) server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityUpstream)
// Verify refcount is 2 // Verify refcount is 2
zoneKey := toZone("shared.example.com") zoneKey := toZone("shared.example.com")
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice") assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
// Deregister one handler // Deregister one handler
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain) server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityUpstream)
// Verify refcount is 1 // Verify refcount is 1
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler") assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
@@ -1925,7 +1925,7 @@ func TestDomainCaseHandling(t *testing.T) {
} }
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault) server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain) server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityUpstream)
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized") assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
@@ -1945,3 +1945,111 @@ func TestDomainCaseHandling(t *testing.T) {
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent") assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present") assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
} }
func TestLocalResolverPriorityInServer(t *testing.T) {
server := &DefaultServer{
ctx: context.Background(),
wgInterface: &mocWGIface{},
handlerChain: NewHandlerChain(),
localResolver: local.NewResolver(),
service: &mockService{},
extraDomains: make(map[domain.Domain]int),
}
config := nbdns.Config{
ServiceEnable: true,
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
NameServerGroups: []*nbdns.NameServerGroup{
{
Domains: []string{"local.example.com"}, // Same domain as local records
NameServers: []nbdns.NameServer{
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: 53,
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
upstreamMuxUpdates, err := server.buildUpstreamHandlerUpdate(config.NameServerGroups)
assert.NoError(t, err)
// Verify that local handler has higher priority than upstream for same domain
var localPriority, upstreamPriority int
localFound, upstreamFound := false, false
for _, update := range localMuxUpdates {
if update.domain == "local.example.com" {
localPriority = update.priority
localFound = true
}
}
for _, update := range upstreamMuxUpdates {
if update.domain == "local.example.com" {
upstreamPriority = update.priority
upstreamFound = true
}
}
assert.True(t, localFound, "Local handler should be found")
assert.True(t, upstreamFound, "Upstream handler should be found")
assert.Greater(t, localPriority, upstreamPriority,
"Local handler priority (%d) should be higher than upstream priority (%d)",
localPriority, upstreamPriority)
assert.Equal(t, PriorityLocal, localPriority, "Local handler should use PriorityLocal")
assert.Equal(t, PriorityUpstream, upstreamPriority, "Upstream handler should use PriorityUpstream")
}
func TestLocalResolverPriorityConstants(t *testing.T) {
// Test that priority constants are ordered correctly
assert.Greater(t, PriorityLocal, PriorityDNSRoute, "Local priority should be higher than DNS route")
assert.Greater(t, PriorityLocal, PriorityUpstream, "Local priority should be higher than upstream")
assert.Greater(t, PriorityUpstream, PriorityDefault, "Upstream priority should be higher than default")
// Test that local resolver uses the correct priority
server := &DefaultServer{
localResolver: local.NewResolver(),
}
config := nbdns.Config{
CustomZones: []nbdns.CustomZone{
{
Domain: "local.example.com",
Records: []nbdns.SimpleRecord{
{
Name: "test.local.example.com",
Type: int(dns.TypeA),
Class: nbdns.DefaultClass,
TTL: 300,
RData: "192.168.1.100",
},
},
},
},
}
localMuxUpdates, _, err := server.buildLocalHandlerUpdate(config.CustomZones)
assert.NoError(t, err)
assert.Len(t, localMuxUpdates, 1)
assert.Equal(t, PriorityLocal, localMuxUpdates[0].priority, "Local handler should use PriorityLocal")
assert.Equal(t, "local.example.com", localMuxUpdates[0].domain)
}

View File

@@ -2,6 +2,7 @@ package dns
import ( import (
"context" "context"
"crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"errors" "errors"
@@ -103,19 +104,21 @@ func (u *upstreamResolverBase) Stop() {
// ServeDNS handles a DNS request // ServeDNS handles a DNS request
func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := GenerateRequestID()
logger := log.WithField("request_id", requestID)
var err error var err error
defer func() { defer func() {
u.checkUpstreamFails(err) u.checkUpstreamFails(err)
}() }()
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) logger.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
if r.Extra == nil { if r.Extra == nil {
r.MsgHdr.AuthenticatedData = true r.MsgHdr.AuthenticatedData = true
} }
select { select {
case <-u.ctx.Done(): case <-u.ctx.Done():
log.Tracef("%s has been stopped", u) logger.Tracef("%s has been stopped", u)
return return
default: default:
} }
@@ -132,35 +135,35 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) { if errors.Is(err, context.DeadlineExceeded) || isTimeout(err) {
log.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("upstream %s timed out for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
log.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err) logger.Warnf("failed to query upstream %s for question domain=%s: %s", upstream, r.Question[0].Name, err)
continue continue
} }
if rm == nil || !rm.Response { if rm == nil || !rm.Response {
log.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name) logger.Warnf("no response from upstream %s for question domain=%s", upstream, r.Question[0].Name)
continue continue
} }
u.successCount.Add(1) u.successCount.Add(1)
log.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name) logger.Tracef("took %s to query the upstream %s for question domain=%s", t, upstream, r.Question[0].Name)
if err = w.WriteMsg(rm); err != nil { if err = w.WriteMsg(rm); err != nil {
log.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err) logger.Errorf("failed to write DNS response for question domain=%s: %s", r.Question[0].Name, err)
} }
// count the fails only if they happen sequentially // count the fails only if they happen sequentially
u.failsCount.Store(0) u.failsCount.Store(0)
return return
} }
u.failsCount.Add(1) u.failsCount.Add(1)
log.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name) logger.Errorf("all queries to the %s failed for question domain=%s", u, r.Question[0].Name)
m := new(dns.Msg) m := new(dns.Msg)
m.SetRcode(r, dns.RcodeServerFailure) m.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(m); err != nil { if err := w.WriteMsg(m); err != nil {
log.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err) logger.Errorf("failed to write error response for %s for question domain=%s: %s", u, r.Question[0].Name, err)
} }
} }
@@ -385,3 +388,13 @@ func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, u
return rm, t, nil return rm, t, nil
} }
func GenerateRequestID() string {
bytes := make([]byte, 4)
_, err := rand.Read(bytes)
if err != nil {
log.Errorf("failed to generate request ID: %v", err)
return ""
}
return hex.EncodeToString(bytes)
}

View File

@@ -18,14 +18,20 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors" nberrors "github.com/netbirdio/netbird/client/errors"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
const errResolveFailed = "failed to resolve query for domain=%s: %v" const errResolveFailed = "failed to resolve query for domain=%s: %v"
const upstreamTimeout = 15 * time.Second const upstreamTimeout = 15 * time.Second
type resolver interface {
LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error)
}
type firewaller interface {
UpdateSet(set firewall.Set, prefixes []netip.Prefix) error
}
type DNSForwarder struct { type DNSForwarder struct {
listenAddress string listenAddress string
ttl uint32 ttl uint32
@@ -38,16 +44,18 @@ type DNSForwarder struct {
mutex sync.RWMutex mutex sync.RWMutex
fwdEntries []*ForwarderEntry fwdEntries []*ForwarderEntry
firewall firewall.Manager firewall firewaller
resolver resolver
} }
func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewall.Manager, statusRecorder *peer.Status) *DNSForwarder { func NewDNSForwarder(listenAddress string, ttl uint32, firewall firewaller, statusRecorder *peer.Status) *DNSForwarder {
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl) log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
return &DNSForwarder{ return &DNSForwarder{
listenAddress: listenAddress, listenAddress: listenAddress,
ttl: ttl, ttl: ttl,
firewall: firewall, firewall: firewall,
statusRecorder: statusRecorder, statusRecorder: statusRecorder,
resolver: net.DefaultResolver,
} }
} }
@@ -57,14 +65,17 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// UDP server // UDP server
mux := dns.NewServeMux() mux := dns.NewServeMux()
f.mux = mux f.mux = mux
mux.HandleFunc(".", f.handleDNSQueryUDP)
f.dnsServer = &dns.Server{ f.dnsServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "udp", Net: "udp",
Handler: mux, Handler: mux,
} }
// TCP server // TCP server
tcpMux := dns.NewServeMux() tcpMux := dns.NewServeMux()
f.tcpMux = tcpMux f.tcpMux = tcpMux
tcpMux.HandleFunc(".", f.handleDNSQueryTCP)
f.tcpServer = &dns.Server{ f.tcpServer = &dns.Server{
Addr: f.listenAddress, Addr: f.listenAddress,
Net: "tcp", Net: "tcp",
@@ -87,30 +98,13 @@ func (f *DNSForwarder) Listen(entries []*ForwarderEntry) error {
// return the first error we get (e.g. bind failure or shutdown) // return the first error we get (e.g. bind failure or shutdown)
return <-errCh return <-errCh
} }
func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) { func (f *DNSForwarder) UpdateDomains(entries []*ForwarderEntry) {
f.mutex.Lock() f.mutex.Lock()
defer f.mutex.Unlock() defer f.mutex.Unlock()
if f.mux == nil {
log.Debug("DNS mux is nil, skipping domain update")
f.fwdEntries = entries
return
}
oldDomains := filterDomains(f.fwdEntries)
for _, d := range oldDomains {
f.mux.HandleRemove(d.PunycodeString())
f.tcpMux.HandleRemove(d.PunycodeString())
}
newDomains := filterDomains(entries)
for _, d := range newDomains {
f.mux.HandleFunc(d.PunycodeString(), f.handleDNSQueryUDP)
f.tcpMux.HandleFunc(d.PunycodeString(), f.handleDNSQueryTCP)
}
f.fwdEntries = entries f.fwdEntries = entries
log.Debugf("Updated domains from %v to %v", oldDomains, newDomains) log.Debugf("Updated DNS forwarder with %d domains", len(entries))
} }
func (f *DNSForwarder) Close(ctx context.Context) error { func (f *DNSForwarder) Close(ctx context.Context) error {
@@ -157,22 +151,31 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
return nil return nil
} }
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
// query doesn't match any configured domain
if mostSpecificResId == "" {
resp.Rcode = dns.RcodeRefused
if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS response: %v", err)
}
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
defer cancel() defer cancel()
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain) ips, err := f.resolver.LookupNetIP(ctx, network, domain)
if err != nil { if err != nil {
f.handleDNSError(w, query, resp, domain, err) f.handleDNSError(w, query, resp, domain, err)
return nil return nil
} }
f.updateInternalState(domain, ips) f.updateInternalState(ips, mostSpecificResId, matchingEntries)
f.addIPsToResponse(resp, domain, ips) f.addIPsToResponse(resp, domain, ips)
return resp return resp
} }
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
resp := f.handleDNSQuery(w, query) resp := f.handleDNSQuery(w, query)
if resp == nil { if resp == nil {
return return
@@ -206,9 +209,8 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
} }
} }
func (f *DNSForwarder) updateInternalState(domain string, ips []netip.Addr) { func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
var prefixes []netip.Prefix var prefixes []netip.Prefix
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
if mostSpecificResId != "" { if mostSpecificResId != "" {
for _, ip := range ips { for _, ip := range ips {
var prefix netip.Prefix var prefix netip.Prefix
@@ -339,16 +341,3 @@ func (f *DNSForwarder) getMatchingEntries(domain string) (route.ResID, []*Forwar
return selectedResId, matches return selectedResId, matches
} }
// filterDomains returns a list of normalized domains
func filterDomains(entries []*ForwarderEntry) domain.List {
newDomains := make(domain.List, 0, len(entries))
for _, d := range entries {
if d.Domain == "" {
log.Warn("empty domain in DNS forwarder")
continue
}
newDomains = append(newDomains, domain.Domain(nbdns.NormalizeZone(d.Domain.PunycodeString())))
}
return newDomains
}

View File

@@ -1,11 +1,21 @@
package dnsfwd package dnsfwd
import ( import (
"context"
"fmt"
"net/netip"
"strings"
"testing" "testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -13,7 +23,7 @@ import (
func Test_getMatchingEntries(t *testing.T) { func Test_getMatchingEntries(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
storedMappings map[string]route.ResID // key: domain pattern, value: resId storedMappings map[string]route.ResID
queryDomain string queryDomain string
expectedResId route.ResID expectedResId route.ResID
}{ }{
@@ -44,7 +54,7 @@ func Test_getMatchingEntries(t *testing.T) {
{ {
name: "Wildcard pattern does not match different domain", name: "Wildcard pattern does not match different domain",
storedMappings: map[string]route.ResID{"*.example.com": "res4"}, storedMappings: map[string]route.ResID{"*.example.com": "res4"},
queryDomain: "foo.notexample.com", queryDomain: "foo.example.org",
expectedResId: "", expectedResId: "",
}, },
{ {
@@ -101,3 +111,619 @@ func Test_getMatchingEntries(t *testing.T) {
}) })
} }
} }
type MockFirewall struct {
mock.Mock
}
func (m *MockFirewall) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
args := m.Called(set, prefixes)
return args.Error(0)
}
type MockResolver struct {
mock.Mock
}
func (m *MockResolver) LookupNetIP(ctx context.Context, network, host string) ([]netip.Addr, error) {
args := m.Called(ctx, network, host)
return args.Get(0).([]netip.Addr), args.Error(1)
}
func TestDNSForwarder_SubdomainAccessLogic(t *testing.T) {
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldMatch bool
expectedResID route.ResID
description string
}{
{
name: "exact domain match should be allowed",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Direct match to configured domain should work",
},
{
name: "subdomain access should be restricted",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Subdomain should not be accessible unless explicitly configured",
},
{
name: "wildcard should allow subdomains",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard domains should allow subdomain access",
},
{
name: "wildcard should allow base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should also match the base domain",
},
{
name: "deep subdomain should be restricted",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: false,
expectedResID: "",
description: "Deep subdomains should not be accessible",
},
{
name: "wildcard allows deep subdomains",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldMatch: true,
expectedResID: "test-res-id",
description: "Wildcard should allow deep subdomains",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := &DNSForwarder{}
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
},
}
forwarder.UpdateDomains(entries)
resID, matchingEntries := forwarder.getMatchingEntries(tt.queryDomain)
if tt.shouldMatch {
assert.Equal(t, tt.expectedResID, resID, "Expected matching ResID")
assert.NotEmpty(t, matchingEntries, "Expected matching entries")
t.Logf("✓ Domain %s correctly matches pattern %s", tt.queryDomain, tt.configuredDomain)
} else {
assert.Equal(t, tt.expectedResID, resID, "Expected no ResID match")
assert.Empty(t, matchingEntries, "Expected no matching entries")
t.Logf("✓ Domain %s correctly does NOT match pattern %s", tt.queryDomain, tt.configuredDomain)
}
})
}
}
func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
tests := []struct {
name string
configuredDomain string
queryDomain string
shouldResolve bool
description string
}{
{
name: "configured exact domain resolves",
configuredDomain: "example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Exact match should resolve",
},
{
name: "unauthorized subdomain blocked",
configuredDomain: "example.com",
queryDomain: "mail.example.com",
shouldResolve: false,
description: "Subdomain should be blocked without wildcard",
},
{
name: "wildcard allows subdomain",
configuredDomain: "*.example.com",
queryDomain: "mail.example.com",
shouldResolve: true,
description: "Wildcard should allow subdomain",
},
{
name: "wildcard allows base domain",
configuredDomain: "*.example.com",
queryDomain: "example.com",
shouldResolve: true,
description: "Wildcard should allow base domain",
},
{
name: "unrelated domain blocked",
configuredDomain: "example.com",
queryDomain: "example.org",
shouldResolve: false,
description: "Unrelated domain should be blocked",
},
{
name: "deep subdomain blocked",
configuredDomain: "example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: false,
description: "Deep subdomain should be blocked",
},
{
name: "wildcard allows deep subdomain",
configuredDomain: "*.example.com",
queryDomain: "deep.mail.example.com",
shouldResolve: true,
description: "Wildcard should allow deep subdomain",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
if tt.shouldResolve {
mockFirewall.On("UpdateSet", mock.AnythingOfType("manager.Set"), mock.AnythingOfType("[]netip.Prefix")).Return(nil)
// Mock successful DNS resolution
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.queryDomain)).Return([]netip.Addr{fakeIP}, nil)
}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
d, err := domain.FromString(tt.configuredDomain)
require.NoError(t, err)
entries := []*ForwarderEntry{
{
Domain: d,
ResID: "test-res-id",
Set: firewall.NewDomainSet([]domain.Domain{d}),
},
}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
assert.NotEmpty(t, resp.Answer, "Expected DNS answer records")
time.Sleep(10 * time.Millisecond)
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
} else {
if resp != nil {
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
"Unauthorized domain should not return successful answers")
}
mockFirewall.AssertNotCalled(t, "UpdateSet")
mockResolver.AssertNotCalled(t, "LookupNetIP")
}
})
}
}
func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
tests := []struct {
name string
configuredDomains []string
query string
mockIP string
shouldResolve bool
expectedSetCount int // How many sets should be updated
description string
}{
{
name: "exact domain gets firewall update",
configuredDomains: []string{"example.com"},
query: "example.com",
mockIP: "1.1.1.1",
shouldResolve: true,
expectedSetCount: 1,
description: "Single exact match updates one set",
},
{
name: "wildcard domain gets firewall update",
configuredDomains: []string{"*.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.2",
shouldResolve: true,
expectedSetCount: 1,
description: "Wildcard match updates one set",
},
{
name: "overlapping exact and wildcard both get updates",
configuredDomains: []string{"*.example.com", "mail.example.com"},
query: "mail.example.com",
mockIP: "1.1.1.3",
shouldResolve: true,
expectedSetCount: 2,
description: "Both exact and wildcard sets should be updated",
},
{
name: "unauthorized domain gets no firewall update",
configuredDomains: []string{"example.com"},
query: "mail.example.com",
mockIP: "1.1.1.4",
shouldResolve: false,
expectedSetCount: 0,
description: "No firewall update for unauthorized domains",
},
{
name: "multiple wildcards matching get all updated",
configuredDomains: []string{"*.example.com", "*.sub.example.com"},
query: "test.sub.example.com",
mockIP: "1.1.1.5",
shouldResolve: true,
expectedSetCount: 2,
description: "All matching wildcard sets should be updated",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
// Set up forwarder
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Create entries and track sets
var entries []*ForwarderEntry
sets := make([]firewall.Set, 0)
for i, configDomain := range tt.configuredDomains {
d, err := domain.FromString(configDomain)
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
sets = append(sets, set)
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID(fmt.Sprintf("res-%d", i)),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Set up mocks
if tt.shouldResolve {
fakeIP := netip.MustParseAddr(tt.mockIP)
mockResolver.On("LookupNetIP", mock.Anything, "ip4", dns.Fqdn(tt.query)).
Return([]netip.Addr{fakeIP}, nil).Once()
expectedPrefixes := []netip.Prefix{netip.PrefixFrom(fakeIP, 32)}
// Count how many sets should actually match
updateCount := 0
for i, entry := range entries {
domain := strings.ToLower(tt.query)
pattern := entry.Domain.PunycodeString()
matches := false
if strings.HasPrefix(pattern, "*.") {
baseDomain := strings.TrimPrefix(pattern, "*.")
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
matches = true
}
} else if domain == pattern {
matches = true
}
if matches {
mockFirewall.On("UpdateSet", sets[i], expectedPrefixes).Return(nil).Once()
updateCount++
}
}
assert.Equal(t, tt.expectedSetCount, updateCount,
"Expected %d sets to be updated, but mock expects %d",
tt.expectedSetCount, updateCount)
}
// Execute query
dnsQuery := &dns.Msg{}
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, dnsQuery)
// Verify response
if tt.shouldResolve {
require.NotNil(t, resp, "Expected response for authorized domain")
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.NotEmpty(t, resp.Answer)
} else if resp != nil {
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
"Unauthorized domain should be refused or have no answers")
}
// Verify all mock expectations were met
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
})
}
}
// Test to verify that multiple IPs for one domain result in all prefixes being sent together
func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Configure a single domain
d, err := domain.FromString("example.com")
require.NoError(t, err)
set := firewall.NewDomainSet([]domain.Domain{d})
entries := []*ForwarderEntry{{
Domain: d,
ResID: "test-res",
Set: set,
}}
forwarder.UpdateDomains(entries)
// Mock resolver returns multiple IPs
ips := []netip.Addr{
netip.MustParseAddr("1.1.1.1"),
netip.MustParseAddr("1.1.1.2"),
netip.MustParseAddr("1.1.1.3"),
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
Return(ips, nil).Once()
// Expect ONE UpdateSet call with ALL prefixes
expectedPrefixes := []netip.Prefix{
netip.PrefixFrom(ips[0], 32),
netip.PrefixFrom(ips[1], 32),
netip.PrefixFrom(ips[2], 32),
}
mockFirewall.On("UpdateSet", set, expectedPrefixes).Return(nil).Once()
// Execute query
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
// Verify response contains all IPs
require.NotNil(t, resp)
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
// Verify mocks
mockFirewall.AssertExpectations(t)
mockResolver.AssertExpectations(t)
}
func TestDNSForwarder_ResponseCodes(t *testing.T) {
tests := []struct {
name string
queryType uint16
queryDomain string
configured string
expectedCode int
description string
}{
{
name: "unauthorized domain returns REFUSED",
queryType: dns.TypeA,
queryDomain: "evil.com",
configured: "example.com",
expectedCode: dns.RcodeRefused,
description: "RFC compliant REFUSED for unauthorized queries",
},
{
name: "unsupported query type returns NOTIMP",
queryType: dns.TypeMX,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "RFC compliant NOTIMP for unsupported types",
},
{
name: "CNAME query returns NOTIMP",
queryType: dns.TypeCNAME,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "CNAME queries not supported",
},
{
name: "TXT query returns NOTIMP",
queryType: dns.TypeTXT,
queryDomain: "example.com",
configured: "example.com",
expectedCode: dns.RcodeNotImplemented,
description: "TXT queries not supported",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
d, err := domain.FromString(tt.configured)
require.NoError(t, err)
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
query := &dns.Msg{}
query.SetQuestion(dns.Fqdn(tt.queryDomain), tt.queryType)
// Capture the written response
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
_ = forwarder.handleDNSQuery(mockWriter, query)
// Check the response written to the writer
require.NotNil(t, writtenResp, "Expected response to be written")
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
})
}
}
func TestDNSForwarder_TCPTruncation(t *testing.T) {
// Test that large UDP responses are truncated with TC bit set
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
forwarder.resolver = mockResolver
d, _ := domain.FromString("example.com")
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res"}}
forwarder.UpdateDomains(entries)
// Mock many IPs to create a large response
var manyIPs []netip.Addr
for i := 0; i < 100; i++ {
manyIPs = append(manyIPs, netip.MustParseAddr(fmt.Sprintf("1.1.1.%d", i%256)))
}
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").Return(manyIPs, nil)
// Query without EDNS0
query := &dns.Msg{}
query.SetQuestion("example.com.", dns.TypeA)
var writtenResp *dns.Msg
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writtenResp = m
return nil
},
}
forwarder.handleDNSQueryUDP(mockWriter, query)
require.NotNil(t, writtenResp)
assert.True(t, writtenResp.Truncated, "Large response should be truncated")
assert.LessOrEqual(t, writtenResp.Len(), dns.MinMsgSize, "Response should fit in minimum UDP size")
}
func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
// Test complex overlapping pattern scenarios
mockFirewall := &MockFirewall{}
mockResolver := &MockResolver{}
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
forwarder.resolver = mockResolver
// Set up complex overlapping patterns
patterns := []string{
"*.example.com", // Matches all subdomains
"*.mail.example.com", // More specific wildcard
"smtp.mail.example.com", // Exact match
"example.com", // Base domain
}
var entries []*ForwarderEntry
sets := make(map[string]firewall.Set)
for _, pattern := range patterns {
d, _ := domain.FromString(pattern)
set := firewall.NewDomainSet([]domain.Domain{d})
sets[pattern] = set
entries = append(entries, &ForwarderEntry{
Domain: d,
ResID: route.ResID("res-" + pattern),
Set: set,
})
}
forwarder.UpdateDomains(entries)
// Test smtp.mail.example.com - should match 3 patterns
fakeIP := netip.MustParseAddr("1.2.3.4")
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "smtp.mail.example.com.").Return([]netip.Addr{fakeIP}, nil)
expectedPrefix := netip.PrefixFrom(fakeIP, 32)
// All three matching patterns should get firewall updates
mockFirewall.On("UpdateSet", sets["smtp.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.mail.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
mockFirewall.On("UpdateSet", sets["*.example.com"], []netip.Prefix{expectedPrefix}).Return(nil)
query := &dns.Msg{}
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
mockWriter := &test.MockResponseWriter{}
resp := forwarder.handleDNSQuery(mockWriter, query)
require.NotNil(t, resp)
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
// Verify all three sets were updated
mockFirewall.AssertExpectations(t)
// Verify the most specific ResID was selected
// (exact match should win over wildcards)
resID, matches := forwarder.getMatchingEntries("smtp.mail.example.com")
assert.Equal(t, route.ResID("res-smtp.mail.example.com"), resID)
assert.Len(t, matches, 3, "Should match 3 patterns")
}
func TestDNSForwarder_EmptyQuery(t *testing.T) {
// Test handling of malformed query with no questions
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})
query := &dns.Msg{}
// Don't set any question
writeCalled := false
mockWriter := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
writeCalled = true
return nil
},
}
resp := forwarder.handleDNSQuery(mockWriter, query)
assert.Nil(t, resp, "Should return nil for empty query")
assert.False(t, writeCalled, "Should not write response for empty query")
}

View File

@@ -1527,6 +1527,7 @@ func (e *Engine) newWgIface() (*iface.WGIface, error) {
MTU: iface.DefaultMTU, MTU: iface.DefaultMTU,
TransportNet: transportNet, TransportNet: transportNet,
FilterFn: e.addrViaRoutes, FilterFn: e.addrViaRoutes,
DisableDNS: e.config.DisableDNS,
} }
switch runtime.GOOS { switch runtime.GOOS {

View File

@@ -575,13 +575,12 @@ func (d *Status) UpdatePeerFQDN(peerPubKey, fqdn string) error {
// FinishPeerListModifications this event invoke the notification // FinishPeerListModifications this event invoke the notification
func (d *Status) FinishPeerListModifications() { func (d *Status) FinishPeerListModifications() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification { if !d.peerListChangedForNotification {
d.mux.Unlock()
return return
} }
d.peerListChangedForNotification = false d.peerListChangedForNotification = false
d.mux.Unlock()
d.notifyPeerListChanged() d.notifyPeerListChanged()

View File

@@ -144,15 +144,18 @@ func (d *DnsInterceptor) RemoveAllowedIPs() error {
// ServeDNS implements the dns.Handler interface // ServeDNS implements the dns.Handler interface
func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
requestID := nbdns.GenerateRequestID()
logger := log.WithField("request_id", requestID)
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
log.Tracef("received DNS request for domain=%s type=%v class=%v", logger.Tracef("received DNS request for domain=%s type=%v class=%v",
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass) r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
// pass if non A/AAAA query // pass if non A/AAAA query
if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA { if r.Question[0].Qtype != dns.TypeA && r.Question[0].Qtype != dns.TypeAAAA {
d.continueToNextHandler(w, r, "non A/AAAA query") d.continueToNextHandler(w, r, logger, "non A/AAAA query")
return return
} }
@@ -161,19 +164,19 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
d.mu.RUnlock() d.mu.RUnlock()
if peerKey == "" { if peerKey == "" {
d.writeDNSError(w, r, "no current peer key") d.writeDNSError(w, r, logger, "no current peer key")
return return
} }
upstreamIP, err := d.getUpstreamIP(peerKey) upstreamIP, err := d.getUpstreamIP(peerKey)
if err != nil { if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("get upstream IP: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("get upstream IP: %v", err))
return return
} }
client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout) client, err := nbdns.GetClientPrivate(d.wgInterface.Address().IP, d.wgInterface.Name(), nbdns.UpstreamTimeout)
if err != nil { if err != nil {
d.writeDNSError(w, r, fmt.Sprintf("create DNS client: %v", err)) d.writeDNSError(w, r, logger, fmt.Sprintf("create DNS client: %v", err))
return return
} }
@@ -184,9 +187,9 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort) upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), dnsfwd.ListenPort)
reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream) reply, _, err := nbdns.ExchangeWithFallback(context.TODO(), client, r, upstream)
if err != nil { if err != nil {
log.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err) logger.Errorf("failed to exchange DNS request with %s (%s) for domain=%s: %v", upstreamIP.String(), peerKey, r.Question[0].Name, err)
if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil { if err := w.WriteMsg(&dns.Msg{MsgHdr: dns.MsgHdr{Rcode: dns.RcodeServerFailure, Id: r.Id}}); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
return return
} }
@@ -196,34 +199,34 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
answer = reply.Answer answer = reply.Answer
} }
log.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer) logger.Tracef("upstream %s (%s) DNS response for domain=%s answers=%v", upstreamIP.String(), peerKey, r.Question[0].Name, answer)
reply.Id = r.Id reply.Id = r.Id
if err := d.writeMsg(w, reply); err != nil { if err := d.writeMsg(w, reply); err != nil {
log.Errorf("failed writing DNS response: %v", err) logger.Errorf("failed writing DNS response: %v", err)
} }
} }
func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) writeDNSError(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason) logger.Warnf("failed to query upstream for domain=%s: %s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeServerFailure) resp.SetRcode(r, dns.RcodeServerFailure)
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed to write DNS error response: %v", err) logger.Errorf("failed to write DNS error response: %v", err)
} }
} }
// continueToNextHandler signals the handler chain to try the next handler // continueToNextHandler signals the handler chain to try the next handler
func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, reason string) { func (d *DnsInterceptor) continueToNextHandler(w dns.ResponseWriter, r *dns.Msg, logger *log.Entry, reason string) {
log.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason) logger.Tracef("continuing to next handler for domain=%s reason=%s", r.Question[0].Name, reason)
resp := new(dns.Msg) resp := new(dns.Msg)
resp.SetRcode(r, dns.RcodeNameError) resp.SetRcode(r, dns.RcodeNameError)
// Set Zero bit to signal handler chain to continue // Set Zero bit to signal handler chain to continue
resp.MsgHdr.Zero = true resp.MsgHdr.Zero = true
if err := w.WriteMsg(resp); err != nil { if err := w.WriteMsg(resp); err != nil {
log.Errorf("failed writing DNS continue response: %v", err) logger.Errorf("failed writing DNS continue response: %v", err)
} }
} }

View File

@@ -15,7 +15,7 @@ import (
// MockManager is the mock instance of a route manager // MockManager is the mock instance of a route manager
type MockManager struct { type MockManager struct {
ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap) ClassifyRoutesFunc func(routes []*route.Route) (map[route.ID]*route.Route, route.HAMap)
UpdateRoutesFunc func (updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error UpdateRoutesFunc func(updateSerial uint64, serverRoutes map[route.ID]*route.Route, clientRoutes route.HAMap, useNewDNSRoute bool) error
TriggerSelectionFunc func(haMap route.HAMap) TriggerSelectionFunc func(haMap route.HAMap)
GetRouteSelectorFunc func() *routeselector.RouteSelector GetRouteSelectorFunc func() *routeselector.RouteSelector
GetClientRoutesFunc func() route.HAMap GetClientRoutesFunc func() route.HAMap

View File

@@ -32,7 +32,6 @@ func (n *Notifier) SetListener(listener listener.NetworkChangeListener) {
func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) { func (n *Notifier) SetInitialClientRoutes(clientRoutes []*route.Route) {
nets := make([]string, 0) nets := make([]string, 0)
for _, r := range clientRoutes { for _, r := range clientRoutes {
// filter out domain routes
if r.IsDynamic() { if r.IsDynamic() {
continue continue
} }
@@ -46,30 +45,27 @@ func (n *Notifier) OnNewRoutes(idMap route.HAMap) {
if runtime.GOOS != "android" { if runtime.GOOS != "android" {
return return
} }
newNets := make([]string, 0)
var newNets []string
for _, routes := range idMap { for _, routes := range idMap {
for _, r := range routes { for _, r := range routes {
if r.IsDynamic() {
continue
}
newNets = append(newNets, r.Network.String()) newNets = append(newNets, r.Network.String())
} }
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.initialRouteRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }
// OnNewPrefixes is called from iOS only
func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) { func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
newNets := make([]string, 0) newNets := make([]string, 0)
for _, prefix := range prefixes { for _, prefix := range prefixes {
@@ -77,19 +73,11 @@ func (n *Notifier) OnNewPrefixes(prefixes []netip.Prefix) {
} }
sort.Strings(newNets) sort.Strings(newNets)
switch runtime.GOOS { if !n.hasDiff(n.routeRanges, newNets) {
case "android": return
if !n.hasDiff(n.initialRouteRanges, newNets) {
return
}
default:
if !n.hasDiff(n.routeRanges, newNets) {
return
}
} }
n.routeRanges = newNets n.routeRanges = newNets
n.notify() n.notify()
} }

View File

@@ -59,16 +59,16 @@ type Info struct {
Environment Environment Environment Environment
Files []File // for posture checks Files []File // for posture checks
RosenpassEnabled bool RosenpassEnabled bool
RosenpassPermissive bool RosenpassPermissive bool
ServerSSHAllowed bool ServerSSHAllowed bool
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
DisableDNS bool DisableDNS bool
DisableFirewall bool DisableFirewall bool
BlockLANAccess bool BlockLANAccess bool
BlockInbound bool BlockInbound bool
LazyConnectionEnabled bool LazyConnectionEnabled bool
} }

View File

@@ -572,7 +572,7 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool var systrayIconState bool
switch { switch {
case status.Status == string(internal.StatusConnected) && !s.mUp.Disabled(): case status.Status == string(internal.StatusConnected):
s.connected = true s.connected = true
s.sendNotification = true s.sendNotification = true
if s.isUpdateIconActive { if s.isUpdateIconActive {

View File

@@ -12,6 +12,8 @@ import (
"fyne.io/fyne/v2" "fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/version"
) )
type eventHandler struct { type eventHandler struct {
@@ -143,7 +145,7 @@ func (h *eventHandler) handleGitHubClick() {
} }
func (h *eventHandler) handleUpdateClick() { func (h *eventHandler) handleUpdateClick() {
if err := openURL("https://netbird.io/download"); err != nil { if err := openURL(version.DownloadUrl()); err != nil {
log.Errorf("failed to open download URL: %v", err) log.Errorf("failed to open download URL: %v", err)
} }
} }

View File

@@ -24,6 +24,7 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/formatter/hook"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache" nbcache "github.com/netbirdio/netbird/management/server/cache"
@@ -101,6 +102,8 @@ type DefaultAccountManager struct {
accountUpdateLocks sync.Map accountUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64 updateAccountPeersBufferInterval atomic.Int64
loginFilter *loginFilter
} }
// getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups.
@@ -194,6 +197,7 @@ func BuildManager(
proxyController: proxyController, proxyController: proxyController,
settingsManager: settingsManager, settingsManager: settingsManager,
permissionsManager: permissionsManager, permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
} }
am.startWarmup(ctx) am.startWarmup(ctx)
@@ -409,14 +413,15 @@ func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.C
event = activity.AccountPeerLoginExpirationDisabled event = activity.AccountPeerLoginExpirationDisabled
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
} else { } else {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
am.StoreEvent(ctx, userID, accountID, accountID, event, nil) am.StoreEvent(ctx, userID, accountID, accountID, event, nil)
} }
if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration { if oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }
@@ -454,6 +459,10 @@ func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.
func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) { func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, accountID string) func() (time.Duration, bool) {
return func() (time.Duration, bool) { return func() (time.Duration, bool) {
//nolint
ctx := context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
//nolint
ctx = context.WithValue(ctx, hook.ExecutionContextKey, fmt.Sprintf("%s-PEER-EXPIRATION", hook.SystemSource))
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
@@ -478,8 +487,11 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(ctx context.Context, acc
} }
} }
func (am *DefaultAccountManager) checkAndSchedulePeerLoginExpiration(ctx context.Context, accountID string) { func (am *DefaultAccountManager) schedulePeerLoginExpiration(ctx context.Context, accountID string) {
am.peerLoginExpiry.Cancel(ctx, []string{accountID}) if am.peerLoginExpiry.IsSchedulerRunning(accountID) {
log.WithContext(ctx).Tracef("peer login expiration job for account %s is already scheduled", accountID)
return
}
if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok { if nextRun, ok := am.getNextPeerExpiration(ctx, accountID); ok {
go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID)) go am.peerLoginExpiry.Schedule(ctx, nextRun, accountID, am.peerLoginExpirationJob(ctx, accountID))
} }
@@ -1527,6 +1539,10 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth nbcontext.U
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
} }
func (am *DefaultAccountManager) AllowSync(wgPubKey, metahash string) bool {
return true
}
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now() start := time.Now()
defer func() { defer func() {
@@ -1548,6 +1564,9 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
} }
metahash := metaHash(meta, realIP.String())
am.loginFilter.addLogin(peerPubKey, metahash)
return peer, netMap, postureChecks, nil return peer, netMap, postureChecks, nil
} }
@@ -1561,7 +1580,6 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
if err != nil { if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
} }
return nil return nil
} }

View File

@@ -117,4 +117,5 @@ type Manager interface {
UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error)
GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error)
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
AllowSync(string, string) bool
} }

View File

@@ -1862,11 +1862,8 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
wg := &sync.WaitGroup{} wg := &sync.WaitGroup{}
wg.Add(2) wg.Add(1)
manager.peerLoginExpiry = &MockScheduler{ manager.peerLoginExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {
wg.Done()
},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) { ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
wg.Done() wg.Done()
}, },

View File

@@ -664,15 +664,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
func (am *DefaultAccountManager) anyGroupHasPeers(account *types.Account, groupIDs []string) bool {
for _, groupID := range groupIDs {
if group, exists := account.Groups[groupID]; exists && group.HasPeers() {
return true
}
}
return false
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs)

View File

@@ -166,7 +166,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID) ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
realIP := getRealIP(ctx) realIP := getRealIP(ctx)
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, realIP.String()) sip := realIP.String()
log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s] [%s]", req.WgPubKey, sip, metaHash(extractPeerMeta(ctx, syncReq.GetMeta()), sip))
if syncReq.GetMeta() == nil { if syncReq.GetMeta() == nil {
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP) log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)

View File

@@ -0,0 +1,113 @@
package server
import (
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
const (
loginFilterSize = 100_000 // Size of the login filter map, making it large enough for a future
filterTimeout = 5 * time.Minute // Duration to secure the previous login information in the filter
loggingLimit = 100
loggingLimitOnePeer = 30
loggingTresholdOnePeer = 5 * time.Minute
)
type loginFilter struct {
mu sync.RWMutex
logged map[string]metahash
}
type metahash struct {
hashes map[string]struct{}
counter int
start time.Time
}
func newLoginFilter() *loginFilter {
return &loginFilter{
logged: make(map[string]metahash, loginFilterSize),
}
}
func (l *loginFilter) addLogin(wgPubKey, metaHash string) {
l.mu.Lock()
defer l.mu.Unlock()
mh, ok := l.logged[wgPubKey]
if !ok {
mh = metahash{
hashes: make(map[string]struct{}, loggingLimit),
start: time.Now(),
}
}
mh.hashes[metaHash] = struct{}{}
mh.counter++
if mh.counter >= loggingLimit && mh.counter%loggingLimit == 0 && len(mh.hashes) > 1 {
log.WithFields(log.Fields{
"wgPubKey": wgPubKey,
"number of different hashes": len(mh.hashes),
"elapsed time for number of attempts": time.Since(mh.start),
"number of syncs": mh.counter,
}).Info(mh.prepareHashes())
} else if mh.counter%loggingLimitOnePeer == 0 && time.Since(mh.start) > loggingTresholdOnePeer && len(mh.hashes) == 1 {
log.WithFields(log.Fields{
"wgPubKey": wgPubKey,
"elapsed time for number of attempts": time.Since(mh.start),
"number of syncs": mh.counter,
}).Info(mh.prepareHashes())
mh.start = time.Now()
}
l.logged[wgPubKey] = mh
}
func (m *metahash) prepareHashes() string {
var sb strings.Builder
for hash := range m.hashes {
sb.WriteString(hash)
sb.WriteString(", ")
}
return sb.String()
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
b.WriteString(meta.WtVersion)
b.WriteByte('|')
b.WriteString(meta.OSVersion)
b.WriteByte('|')
b.WriteString(meta.KernelVersion)
b.WriteByte('|')
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
b.WriteByte('|')
b.WriteString(mac)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}

View File

@@ -119,6 +119,8 @@ type MockAccountManager struct {
GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error)
GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
AllowSyncFunc func(string, string) bool
} }
func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) { func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
@@ -890,3 +892,7 @@ func (am *MockAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth n
} }
return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetCurrentUserInfo is not implemented")
} }
func (am *MockAccountManager) AllowSync(_, _ string) bool {
return true
}

View File

@@ -133,7 +133,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.schedulePeerLoginExpiration(ctx, accountID)
} }
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
@@ -296,7 +296,8 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain)) am.StoreEvent(ctx, userID, peer.IP.String(), accountID, event, peer.EventMeta(dnsDomain))
if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { if peer.AddedWithSSOLogin() && peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.peerLoginExpiry.Cancel(ctx, []string{accountID})
am.schedulePeerLoginExpiration(ctx, accountID)
} }
} }

View File

@@ -1579,7 +1579,6 @@ func Test_LoginPeer(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
setupKey string setupKey string
wireGuardPubKey string
expectExtraDNSLabelsMismatch bool expectExtraDNSLabelsMismatch bool
extraDNSLabels []string extraDNSLabels []string
expectLoginError bool expectLoginError bool

View File

@@ -4,19 +4,19 @@ import (
"context" "context"
"fmt" "fmt"
"net/netip" "net/netip"
"slices"
"unicode/utf8" "unicode/utf8"
"github.com/rs/xid" "github.com/rs/xid"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto" "github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -30,13 +30,19 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, string(routeID), accountID) return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account *types.Account, peerID string, routeID route.ID, peerGroupIDs []string, prefix netip.Prefix, domains domain.List) error { func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction store.Store, accountID string, checkRoute *route.Route, groupsMap map[string]*types.Group) error {
// routes can have both peer and peer_groups // routes can have both peer and peer_groups
routesWithPrefix := account.GetRoutesByPrefixOrDomains(prefix, domains) prefix := checkRoute.Network
domains := checkRoute.Domains
routesWithPrefix, err := getRoutesByPrefixOrDomains(ctx, transaction, accountID, prefix, domains)
if err != nil {
return err
}
// lets remember all the peers and the peer groups from routesWithPrefix // lets remember all the peers and the peer groups from routesWithPrefix
seenPeers := make(map[string]bool) seenPeers := make(map[string]bool)
@@ -45,18 +51,24 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
for _, prefixRoute := range routesWithPrefix { for _, prefixRoute := range routesWithPrefix {
// we skip route(s) with the same network ID as we want to allow updating of the existing route // we skip route(s) with the same network ID as we want to allow updating of the existing route
// when creating a new route routeID is newly generated so nothing will be skipped // when creating a new route routeID is newly generated so nothing will be skipped
if routeID == prefixRoute.ID { if checkRoute.ID == prefixRoute.ID {
continue continue
} }
if prefixRoute.Peer != "" { if prefixRoute.Peer != "" {
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups)
if err != nil {
return err
}
for _, groupID := range prefixRoute.PeerGroups { for _, groupID := range prefixRoute.PeerGroups {
seenPeerGroups[groupID] = true seenPeerGroups[groupID] = true
group := account.GetGroup(groupID) group, ok := peerGroupsMap[groupID]
if group == nil { if !ok || group == nil {
return status.Errorf( return status.Errorf(
status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist", status.InvalidArgument, "failed to add route with %s - peer group %s doesn't exist",
getRouteDescriptor(prefix, domains), groupID, getRouteDescriptor(prefix, domains), groupID,
@@ -69,12 +81,13 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
} }
if peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
peer := account.GetPeer(peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID)
if peer == nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
if _, ok := seenPeers[peerID]; ok { if _, ok := seenPeers[peerID]; ok {
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID) "failed to add route with %s - peer %s already has this route", getRouteDescriptor(prefix, domains), peerID)
@@ -82,9 +95,8 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that peerGroupIDs are not in any route peerGroups list // check that peerGroupIDs are not in any route peerGroups list
for _, groupID := range peerGroupIDs { for _, groupID := range checkRoute.PeerGroups {
group := account.GetGroup(groupID) // we validated the group existence before entering this function, no need to check again. group := groupsMap[groupID] // we validated the group existence before entering this function, no need to check again.
if _, ok := seenPeerGroups[groupID]; ok { if _, ok := seenPeerGroups[groupID]; ok {
return status.Errorf( return status.Errorf(
status.AlreadyExists, "failed to add route with %s - peer group %s already has this route", status.AlreadyExists, "failed to add route with %s - peer group %s already has this route",
@@ -92,12 +104,18 @@ func (am *DefaultAccountManager) checkRoutePrefixOrDomainsExistForPeers(account
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers)
if err != nil {
return err
}
for _, id := range group.Peers { for _, id := range group.Peers {
if _, ok := seenPeers[id]; ok { if _, ok := seenPeers[id]; ok {
peer := account.GetPeer(id) peer, ok := peersMap[id]
if peer == nil { if !ok || peer == nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", id)
} }
return status.Errorf(status.AlreadyExists, return status.Errorf(status.AlreadyExists,
"failed to add route with %s - peer %s from the group %s already has this route", "failed to add route with %s - peer %s from the group %s already has this route",
getRouteDescriptor(prefix, domains), peer.Name, group.Name) getRouteDescriptor(prefix, domains), peer.Name, group.Name)
@@ -128,97 +146,58 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
if len(domains) > 0 && prefix.IsValid() { if len(domains) > 0 && prefix.IsValid() {
return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return nil, status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
if len(domains) == 0 && !prefix.IsValid() { var newRoute *route.Route
return nil, status.Errorf(status.InvalidArgument, "invalid Prefix") var updateAccountPeers bool
}
if len(domains) > 0 { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
prefix = getPlaceholderIP() newRoute = &route.Route{
} ID: route.ID(xid.New().String()),
AccountID: accountID,
if peerID != "" && len(peerGroupIDs) != 0 { Network: prefix,
return nil, status.Errorf( Domains: domains,
status.InvalidArgument, KeepRoute: keepRoute,
"peer with ID %s and peers group %s should not be provided at the same time", NetID: netID,
peerID, peerGroupIDs) Description: description,
} Peer: peerID,
PeerGroups: peerGroupIDs,
var newRoute route.Route NetworkType: networkType,
newRoute.ID = route.ID(xid.New().String()) Masquerade: masquerade,
Metric: metric,
if len(peerGroupIDs) > 0 { Enabled: enabled,
err = validateGroups(peerGroupIDs, account.Groups) Groups: groups,
if err != nil { AccessControlGroups: accessControlGroupIDs,
return nil, err
} }
}
if len(accessControlGroupIDs) > 0 { if err = validateRoute(ctx, transaction, accountID, newRoute); err != nil {
err = validateGroups(accessControlGroupIDs, account.Groups) return err
if err != nil {
return nil, err
} }
}
err = am.checkRoutePrefixOrDomainsExistForPeers(account, peerID, newRoute.ID, peerGroupIDs, prefix, domains) updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, newRoute)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if metric < route.MinMetric || metric > route.MaxMetric {
return nil, status.Errorf(status.InvalidArgument, "metric should be between %d and %d", route.MinMetric, route.MaxMetric)
}
if utf8.RuneCountInString(string(netID)) > route.MaxNetIDChar || netID == "" {
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
}
err = validateGroups(groups, account.Groups)
if err != nil {
return nil, err
}
newRoute.Peer = peerID
newRoute.PeerGroups = peerGroupIDs
newRoute.Network = prefix
newRoute.Domains = domains
newRoute.NetworkType = networkType
newRoute.Description = description
newRoute.NetID = netID
newRoute.Masquerade = masquerade
newRoute.Metric = metric
newRoute.Enabled = enabled
newRoute.Groups = groups
newRoute.KeepRoute = keepRoute
newRoute.AccessControlGroups = accessControlGroupIDs
if account.Routes == nil {
account.Routes = make(map[route.ID]*route.Route)
}
account.Routes[newRoute.ID] = &newRoute
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return nil, err
}
if am.isRouteChangeAffectPeers(account, &newRoute) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
return &newRoute, nil if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return newRoute, nil
} }
// SaveRoute saves route // SaveRoute saves route
@@ -226,6 +205,115 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
return err
}
oldRoute, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeToSave.ID))
if err != nil {
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.SaveRoute(ctx, store.LockingStrengthUpdate, routeToSave)
})
if err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
return transaction.DeleteRoute(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
})
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
}
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
if routeToSave == nil { if routeToSave == nil {
return status.Errorf(status.InvalidArgument, "route provided is nil") return status.Errorf(status.InvalidArgument, "route provided is nil")
} }
@@ -238,19 +326,6 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar) return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
} }
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Update)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() { if len(routeToSave.Domains) > 0 && routeToSave.Network.IsValid() {
return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time") return status.Errorf(status.InvalidArgument, "domains and network should not be provided at the same time")
} }
@@ -267,96 +342,39 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time") return status.Errorf(status.InvalidArgument, "peer with ID and peer groups should not be provided at the same time")
} }
groupsMap, err := validateRouteGroups(ctx, transaction, accountID, routeToSave)
if err != nil {
return err
}
return checkRoutePrefixOrDomainsExistForPeers(ctx, transaction, accountID, routeToSave, groupsMap)
}
// validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate)
if err != nil {
return nil, err
}
if len(routeToSave.PeerGroups) > 0 { if len(routeToSave.PeerGroups) > 0 {
err = validateGroups(routeToSave.PeerGroups, account.Groups) if err = validateGroups(routeToSave.PeerGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
if len(routeToSave.AccessControlGroups) > 0 { if len(routeToSave.AccessControlGroups) > 0 {
err = validateGroups(routeToSave.AccessControlGroups, account.Groups) if err = validateGroups(routeToSave.AccessControlGroups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
} }
err = am.checkRoutePrefixOrDomainsExistForPeers(account, routeToSave.Peer, routeToSave.ID, routeToSave.Copy().PeerGroups, routeToSave.Network, routeToSave.Domains) if err = validateGroups(routeToSave.Groups, groupsMap); err != nil {
if err != nil { return nil, err
return err
} }
err = validateGroups(routeToSave.Groups, account.Groups) return groupsMap, nil
if err != nil {
return err
}
oldRoute := account.Routes[routeToSave.ID]
account.Routes[routeToSave.ID] = routeToSave
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
if am.isRouteChangeAffectPeers(account, oldRoute) || am.isRouteChangeAffectPeers(account, routeToSave) {
am.UpdateAccountPeers(ctx, accountID)
}
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
return nil
}
// DeleteRoute deletes route with routeID
func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID string, routeID route.ID, userID string) error {
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock()
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Delete)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}
routy := account.Routes[routeID]
if routy == nil {
return status.Errorf(status.NotFound, "route with ID %s doesn't exist", routeID)
}
delete(account.Routes, routeID)
account.Network.IncSerial()
if err = am.Store.SaveAccount(ctx, account); err != nil {
return err
}
am.StoreEvent(ctx, userID, string(routy.ID), accountID, activity.RouteRemoved, routy.EventMeta())
if am.isRouteChangeAffectPeers(account, routy) {
am.UpdateAccountPeers(ctx, accountID)
}
return nil
}
// ListRoutes returns a list of routes from account
func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, userID string) ([]*route.Route, error) {
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Routes, operations.Read)
if err != nil {
return nil, status.NewPermissionValidationError(err)
}
if !allowed {
return nil, status.NewPermissionDeniedError()
}
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
} }
func toProtocolRoute(route *route.Route) *proto.Route { func toProtocolRoute(route *route.Route) *proto.Route {
@@ -455,8 +473,40 @@ func getProtoPortInfo(rule *types.RouteFirewallRule) *proto.PortInfo {
return &portInfo return &portInfo
} }
// isRouteChangeAffectPeers checks if a given route affects peers by determining // areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers // if it has a routing peer, distribution, or peer groups that include peers.
func (am *DefaultAccountManager) isRouteChangeAffectPeers(account *types.Account, route *route.Route) bool { func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
return am.anyGroupHasPeers(account, route.Groups) || am.anyGroupHasPeers(account, route.PeerGroups) || route.Peer != "" if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return nil, err
}
routes := make([]*route.Route, 0)
for _, r := range accountRoutes {
dynamic := r.IsDynamic()
if dynamic && r.Domains.PunycodeString() == domains.PunycodeString() ||
!dynamic && r.Network.String() == prefix.String() {
routes = append(routes, r)
}
}
return routes, nil
} }

View File

@@ -12,6 +12,7 @@ import (
type Scheduler interface { type Scheduler interface {
Cancel(ctx context.Context, IDs []string) Cancel(ctx context.Context, IDs []string)
Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) Schedule(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool))
IsSchedulerRunning(ID string) bool
} }
// MockScheduler is a mock implementation of Scheduler // MockScheduler is a mock implementation of Scheduler
@@ -26,7 +27,7 @@ func (mock *MockScheduler) Cancel(ctx context.Context, IDs []string) {
mock.CancelFunc(ctx, IDs) mock.CancelFunc(ctx, IDs)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Cancel function defined ") log.WithContext(ctx).Warnf("MockScheduler doesn't have Cancel function defined ")
} }
// Schedule mocks the Schedule function of the Scheduler interface // Schedule mocks the Schedule function of the Scheduler interface
@@ -35,7 +36,13 @@ func (mock *MockScheduler) Schedule(ctx context.Context, in time.Duration, ID st
mock.ScheduleFunc(ctx, in, ID, job) mock.ScheduleFunc(ctx, in, ID, job)
return return
} }
log.WithContext(ctx).Errorf("MockScheduler doesn't have Schedule function defined") log.WithContext(ctx).Warnf("MockScheduler doesn't have Schedule function defined")
}
func (mock *MockScheduler) IsSchedulerRunning(ID string) bool {
// MockScheduler does not implement IsSchedulerRunning, so we return false
log.Warnf("MockScheduler doesn't have IsSchedulerRunning function defined")
return false
} }
// DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them. // DefaultScheduler is a generic structure that allows to schedule jobs (functions) to run in the future and cancel them.
@@ -124,3 +131,11 @@ func (wm *DefaultScheduler) Schedule(ctx context.Context, in time.Duration, ID s
}() }()
} }
// IsSchedulerRunning checks if a job with the provided ID is scheduled to run
func (wm *DefaultScheduler) IsSchedulerRunning(ID string) bool {
wm.mu.Lock()
defer wm.mu.Unlock()
_, ok := wm.jobs[ID]
return ok
}

View File

@@ -42,7 +42,10 @@ const (
// Type is a type of the Error // Type is a type of the Error
type Type int32 type Type int32
var ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found") var (
ErrExtraSettingsNotFound = fmt.Errorf("extra settings not found")
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
)
// Error is an internal error // Error is an internal error
type Error struct { type Error struct {
@@ -227,3 +230,7 @@ func NewUserRoleNotFoundError(role string) error {
func NewOperationNotFoundError(operation operations.Operation) error { func NewOperationNotFoundError(operation operations.Operation) error {
return Errorf(NotFound, "operation: %s not found", operation) return Errorf(NotFound, "operation: %s not found", operation)
} }
func NewRouteNotFoundError(routeID string) error {
return Errorf(NotFound, "route: %s not found", routeID)
}

View File

@@ -23,8 +23,6 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
"gorm.io/gorm/logger" "gorm.io/gorm/logger"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -34,6 +32,7 @@ import (
"github.com/netbirdio/netbird/management/server/status" "github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
) )
@@ -1968,12 +1967,58 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db, lockStrength, accountID) var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store")
}
return routes, nil
} }
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
return getRecordByID[route.Route](s.db, lockStrength, routeID, accountID) var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID)
}
log.WithContext(ctx).Errorf("failed to get route from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get route from store")
}
return route, nil
}
// SaveRoute saves a route to the database.
func (s *SqlStore) SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(route)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to save route to the store: %s", err)
return status.Errorf(status.Internal, "failed to save route to store")
}
return nil
}
// DeleteRoute deletes a route from the database.
func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).
Delete(&route.Route{}, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to delete route from the store: %s", err)
return status.Errorf(status.Internal, "failed to delete route from store")
}
if result.RowsAffected == 0 {
return status.NewRouteNotFoundError(routeID)
}
return nil
} }
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
@@ -2104,49 +2149,6 @@ func (s *SqlStore) DeleteNameServerGroup(ctx context.Context, lockStrength Locki
return nil return nil
} }
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record []T
result := tx.Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
return nil, status.Errorf(status.Internal, "failed to get account %ss from store: %v", recordType, err)
}
return record, nil
}
// getRecordByID retrieves a record by its ID and account ID from the database.
func getRecordByID[T any](db *gorm.DB, lockStrength LockingStrength, recordID, accountID string) (*T, error) {
tx := db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var record T
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&record, accountAndIDQueryCondition, accountID, recordID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "%s not found", recordType)
}
return nil, status.Errorf(status.Internal, "failed to get %s from store: %v", recordType, err)
}
return &record, nil
}
// SaveDNSSettings saves the DNS settings to the store. // SaveDNSSettings saves the DNS settings to the store.
func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error { func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).

View File

@@ -19,21 +19,17 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/util"
nbdns "github.com/netbirdio/netbird/dns" nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
route2 "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/management/server/status"
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
nbroute "github.com/netbirdio/netbird/route" nbroute "github.com/netbirdio/netbird/route"
route2 "github.com/netbirdio/netbird/route"
) )
func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) { func runTestForAllEngines(t *testing.T, testDataFile string, f func(t *testing.T, store Store)) {
@@ -3247,6 +3243,132 @@ func TestSqlStore_SaveGroups_LargeBatch(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 8003, len(accountGroups)) require.Equal(t, 8003, len(accountGroups))
} }
func TestSqlStore_GetAccountRoutes(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
tests := []struct {
name string
accountID string
expectedCount int
}{
{
name: "retrieve routes by existing account ID",
accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b",
expectedCount: 1,
},
{
name: "non-existing account ID",
accountID: "nonexistent",
expectedCount: 0,
},
{
name: "empty account ID",
accountID: "",
expectedCount: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routes, err := store.GetAccountRoutes(context.Background(), LockingStrengthShare, tt.accountID)
require.NoError(t, err)
require.Len(t, routes, tt.expectedCount)
})
}
}
func TestSqlStore_GetRouteByID(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
tests := []struct {
name string
routeID string
expectError bool
}{
{
name: "retrieve existing route",
routeID: "ct03t427qv97vmtmglog",
expectError: false,
},
{
name: "retrieve non-existing route",
routeID: "non-existing",
expectError: true,
},
{
name: "retrieve with empty route ID",
routeID: "",
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, tt.routeID)
if tt.expectError {
require.Error(t, err)
sErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, sErr.Type(), status.NotFound)
require.Nil(t, route)
} else {
require.NoError(t, err)
require.NotNil(t, route)
require.Equal(t, tt.routeID, string(route.ID))
}
})
}
}
func TestSqlStore_SaveRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
route := &route2.Route{
ID: "route-id",
AccountID: accountID,
Network: netip.MustParsePrefix("10.10.0.0/16"),
NetID: "netID",
PeerGroups: []string{"routeA"},
NetworkType: route2.IPv4Network,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"groupA"},
AccessControlGroups: []string{},
}
err = store.SaveRoute(context.Background(), LockingStrengthUpdate, route)
require.NoError(t, err)
saveRoute, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, string(route.ID))
require.NoError(t, err)
require.Equal(t, route, saveRoute)
}
func TestSqlStore_DeleteRoute(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())
t.Cleanup(cleanup)
require.NoError(t, err)
accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b"
routeID := "ct03t427qv97vmtmglog"
err = store.DeleteRoute(context.Background(), LockingStrengthUpdate, accountID, routeID)
require.NoError(t, err)
route, err := store.GetRouteByID(context.Background(), LockingStrengthShare, accountID, routeID)
require.Error(t, err)
require.Nil(t, route)
}
func TestSqlStore_GetAccountMeta(t *testing.T) { func TestSqlStore_GetAccountMeta(t *testing.T) {
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir())

View File

@@ -145,7 +145,9 @@ type Store interface {
DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error DeleteSetupKey(ctx context.Context, lockStrength LockingStrength, accountID, keyID string) error
GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error)
GetRouteByID(ctx context.Context, lockStrength LockingStrength, routeID string, accountID string) (*route.Route, error) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) (*route.Route, error)
SaveRoute(ctx context.Context, lockStrength LockingStrength, route *route.Route) error
DeleteRoute(ctx context.Context, lockStrength LockingStrength, accountID, routeID string) error
GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*dns.NameServerGroup, error)
GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, nameServerGroupID string, accountID string) (*dns.NameServerGroup, error)

View File

@@ -38,4 +38,5 @@ INSERT INTO "groups" VALUES('cfefqs706sqkneg59g2g','bf1c8084-ba50-4ce7-9439-3465
INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}'); INSERT INTO posture_checks VALUES('csplshq7qv948l48f7t0','NetBird Version > 0.32.0','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"NBVersionCheck":{"MinVersion":"0.31.0"}}');
INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}'); INSERT INTO posture_checks VALUES('cspnllq7qv95uq1r4k90','Allow Berlin and Deny local network 172.16.1.0/24','','bf1c8084-ba50-4ce7-9439-34653001fc3b','{"GeoLocationCheck":{"Locations":[{"CountryCode":"DE","CityName":"Berlin"}],"Action":"allow"},"PeerNetworkRangeCheck":{"Action":"deny","Ranges":["172.16.1.0/24"]}}');
INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0); INSERT INTO name_server_groups VALUES('csqdelq7qv97ncu7d9t0','bf1c8084-ba50-4ce7-9439-34653001fc3b','Google DNS','Google DNS Servers','[{"IP":"8.8.8.8","NSType":1,"Port":53},{"IP":"8.8.4.4","NSType":1,"Port":53}]','["cfefqs706sqkneg59g2g"]',1,'[]',1,0);
INSERT INTO routes VALUES('ct03t427qv97vmtmglog','bf1c8084-ba50-4ce7-9439-34653001fc3b','"10.10.0.0/16"',NULL,0,'aws-eu-central-1-vpc','Production VPC in Frankfurt','ct03r5q7qv97vmtmglng',NULL,1,1,9999,1,'["cfefqs706sqkneg59g2g"]',NULL);
INSERT INTO installations VALUES(1,''); INSERT INTO installations VALUES(1,'');