mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 11:46:40 +00:00
Compare commits
5 Commits
v0.64.3
...
fix/win-dn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
85e4afe100 | ||
|
|
a72e0e67e5 | ||
|
|
a33f60e3fd | ||
|
|
0c990ab662 | ||
|
|
101c813e98 |
@@ -69,6 +69,8 @@ type Options struct {
|
|||||||
StatePath string
|
StatePath string
|
||||||
// DisableClientRoutes disables the client routes
|
// DisableClientRoutes disables the client routes
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
|
// BlockInbound blocks all inbound connections from peers
|
||||||
|
BlockInbound bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
PreSharedKey: &opts.PreSharedKey,
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
BlockInbound: &opts.BlockInbound,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
|
|||||||
@@ -9,8 +9,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
@@ -38,6 +40,9 @@ const (
|
|||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
createdKeys map[string]struct{}
|
createdKeys map[string]struct{}
|
||||||
systemDNSSettings SystemDNSSettings
|
systemDNSSettings SystemDNSSettings
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
origNameservers []netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager() (*systemConfigurator, error) {
|
func newHostManager() (*systemConfigurator, error) {
|
||||||
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var dnsSettings SystemDNSSettings
|
var dnsSettings SystemDNSSettings
|
||||||
|
var serverAddresses []netip.Addr
|
||||||
inSearchDomainsArray := false
|
inSearchDomainsArray := false
|
||||||
inServerAddressesArray := false
|
inServerAddressesArray := false
|
||||||
|
|
||||||
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
|
||||||
} else if inServerAddressesArray {
|
} else if inServerAddressesArray {
|
||||||
address := strings.Split(line, " : ")[1]
|
address := strings.Split(line, " : ")[1]
|
||||||
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() {
|
if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
|
||||||
dnsSettings.ServerIP = ip.Unmap()
|
ip = ip.Unmap()
|
||||||
inServerAddressesArray = false // Stop reading after finding the first IPv4 address
|
serverAddresses = append(serverAddresses, ip)
|
||||||
|
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
|
||||||
|
dnsSettings.ServerIP = ip
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
|
|||||||
// default to 53 port
|
// default to 53 port
|
||||||
dnsSettings.ServerPort = DefaultPort
|
dnsSettings.ServerPort = DefaultPort
|
||||||
|
|
||||||
|
s.mu.Lock()
|
||||||
|
s.origNameservers = serverAddresses
|
||||||
|
s.mu.Unlock()
|
||||||
|
|
||||||
return dnsSettings, nil
|
return dnsSettings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
return slices.Clone(s.origNameservers)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
err := s.addDNSState(key, domains, ip, port, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
|
|||||||
_, err := cmd.CombinedOutput()
|
_, err := cmd.CombinedOutput()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameservers(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
origNameservers: []netip.Addr{
|
||||||
|
netip.MustParseAddr("8.8.8.8"),
|
||||||
|
netip.MustParseAddr("1.1.1.1"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
assert.Len(t, servers, 2)
|
||||||
|
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
||||||
|
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
|
||||||
|
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
||||||
|
|
||||||
|
for _, server := range servers {
|
||||||
|
assert.True(t, server.IsValid(), "server address should be valid")
|
||||||
|
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
cleanup := func() {
|
||||||
|
_ = sm.Stop(context.Background())
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return configurator, sm, cleanup
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversNoTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
routeAll bool
|
||||||
|
}{
|
||||||
|
{"routeall_false", false},
|
||||||
|
{"routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
for _, srv := range initialServers {
|
||||||
|
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.routeAll,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i <= 2; i++ {
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
||||||
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
initialRoute bool
|
||||||
|
}{
|
||||||
|
{"start_with_routeall_false", false},
|
||||||
|
{"start_with_routeall_true", true},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
_, err := configurator.getSystemDNSSettings()
|
||||||
|
require.NoError(t, err)
|
||||||
|
initialServers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("Initial servers: %v", initialServers)
|
||||||
|
require.NotEmpty(t, initialServers)
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netbirdIP,
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: tc.initialRoute,
|
||||||
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// First apply
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers := configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle RouteAll
|
||||||
|
config.RouteAll = !tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
// Toggle back
|
||||||
|
config.RouteAll = tc.initialRoute
|
||||||
|
err = configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
servers = configurator.getOriginalNameservers()
|
||||||
|
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
||||||
|
assert.Equal(t, initialServers, servers)
|
||||||
|
|
||||||
|
for _, srv := range servers {
|
||||||
|
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ const (
|
|||||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||||
|
|
||||||
|
// NRPT rules cannot handle more than 50 domains per rule.
|
||||||
|
// This is an undocumented Windows limitation.
|
||||||
|
nrptMaxDomainsPerRule = 50
|
||||||
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
@@ -239,23 +243,32 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
|||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
for i, domain := range domains {
|
|
||||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
|
||||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
|
||||||
|
|
||||||
singleDomain := []string{domain}
|
// NRPT rules have an undocumented restriction: each rule can only handle up to 50 domains.
|
||||||
|
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||||
|
ruleIndex := 0
|
||||||
|
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||||
|
end := i + nrptMaxDomainsPerRule
|
||||||
|
if end > len(domains) {
|
||||||
|
end = len(domains)
|
||||||
|
}
|
||||||
|
batchDomains := domains[i:end]
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||||
|
|
||||||
|
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||||
|
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
log.Debugf("added NRPT rule %d with %d domains", ruleIndex, len(batchDomains))
|
||||||
|
ruleIndex++
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
log.Infof("added %d NRPT rules for %d domains. Domain list: %s", ruleIndex, len(domains), domains)
|
||||||
return len(domains), nil
|
return ruleIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||||
|
|||||||
@@ -97,6 +97,107 @@ func registryKeyExists(path string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cleanupRegistryKeys(*testing.T) {
|
func cleanupRegistryKeys(*testing.T) {
|
||||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
// Clean up more entries to account for batching tests with many domains
|
||||||
|
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||||
_ = cfg.removeDNSMatchPolicies()
|
_ = cfg.removeDNSMatchPolicies()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules
|
||||||
|
// with a maximum of 50 domains per rule (Windows limitation).
|
||||||
|
func TestNRPTDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping registry integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanupRegistryKeys(t)
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
testIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||||
|
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||||
|
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||||
|
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||||
|
require.NoError(t, err, "Should create test interface registry key")
|
||||||
|
testKey.Close()
|
||||||
|
defer func() {
|
||||||
|
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := ®istryConfigurator{
|
||||||
|
guid: testGUID,
|
||||||
|
gpo: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
domainCount int
|
||||||
|
expectedRuleCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Less than 50 domains (single rule)",
|
||||||
|
domainCount: 30,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Exactly 50 domains (single rule)",
|
||||||
|
domainCount: 50,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "51 domains (two rules)",
|
||||||
|
domainCount: 51,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100 domains (two rules)",
|
||||||
|
domainCount: 100,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "125 domains (three rules: 50+50+25)",
|
||||||
|
domainCount: 125,
|
||||||
|
expectedRuleCount: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Clean up before each subtest
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
// Generate domains
|
||||||
|
domains := make([]DomainConfig, tc.domainCount)
|
||||||
|
for i := 0; i < tc.domainCount; i++ {
|
||||||
|
domains[i] = DomainConfig{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
|
MatchOnly: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: domains,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.applyDNSConfig(config, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that exactly expectedRuleCount rules were created
|
||||||
|
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||||
|
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||||
|
|
||||||
|
// Verify all expected rules exist
|
||||||
|
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no extra rules were created
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -84,3 +84,13 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
|||||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||||
|
func (m *MockServer) BeginBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch mock implementation of EndBatch from Server interface
|
||||||
|
func (m *MockServer) EndBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ type IosDnsManager interface {
|
|||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains domain.List, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
|
BeginBatch()
|
||||||
|
EndBatch()
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() netip.Addr
|
DnsIP() netip.Addr
|
||||||
@@ -83,6 +85,7 @@ type DefaultServer struct {
|
|||||||
currentConfigHash uint64
|
currentConfigHash uint64
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
batchMode bool
|
||||||
|
|
||||||
mgmtCacheResolver *mgmt.Resolver
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
|
||||||
@@ -230,7 +233,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
// convert to zone with simple ref counter
|
// convert to zone with simple ref counter
|
||||||
s.extraDomains[toZone(domain)]++
|
s.extraDomains[toZone(domain)]++
|
||||||
}
|
}
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@@ -259,6 +264,28 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
delete(s.extraDomains, zone)
|
delete(s.extraDomains, zone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||||
|
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||||
|
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||||
|
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||||
|
func (s *DefaultServer) BeginBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Infof("DNS batch mode enabled")
|
||||||
|
s.batchMode = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||||
|
func (s *DefaultServer) EndBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Infof("DNS batch mode disabled, applying accumulated changes")
|
||||||
|
s.batchMode = false
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -508,7 +535,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
|
||||||
s.shutdownWg.Add(1)
|
s.shutdownWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -615,7 +644,7 @@ func (s *DefaultServer) applyHostConfig() {
|
|||||||
s.registerFallback(config)
|
s.registerFallback(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
// registerFallback registers original nameservers as low-priority fallback handlers
|
// registerFallback registers original nameservers as low-priority fallback handlers.
|
||||||
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
||||||
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -624,6 +653,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
|
|
||||||
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
originalNameservers := hostMgrWithNS.getOriginalNameservers()
|
||||||
if len(originalNameservers) == 0 {
|
if len(originalNameservers) == 0 {
|
||||||
|
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -871,7 +901,9 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
@@ -906,7 +938,9 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
|
|||||||
t.Errorf("invalid dns server instance: %s", err)
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if srvB != srv {
|
mockSrvB, ok := srvB.(*MockServer)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("returned server is not a MockServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mockSrvB != srv {
|
||||||
t.Errorf("mismatch dns instances")
|
t.Errorf("mismatch dns instances")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,15 +8,21 @@ import (
|
|||||||
|
|
||||||
type MockResponseWriter struct {
|
type MockResponseWriter struct {
|
||||||
WriteMsgFunc func(m *dns.Msg) error
|
WriteMsgFunc func(m *dns.Msg) error
|
||||||
|
lastResponse *dns.Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
|
||||||
|
rw.lastResponse = m
|
||||||
if rw.WriteMsgFunc != nil {
|
if rw.WriteMsgFunc != nil {
|
||||||
return rw.WriteMsgFunc(m)
|
return rw.WriteMsgFunc(m)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
|
||||||
|
return rw.lastResponse
|
||||||
|
}
|
||||||
|
|
||||||
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
|
||||||
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
||||||
|
|||||||
@@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
log.Infof("sync finished in %s", time.Since(started))
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
|
|||||||
|
|
||||||
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
// Listen begins monitoring network changes. When a change is detected, this function will return without error.
|
||||||
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Network monitor: skipping in netstack mode")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
nw.mu.Lock()
|
nw.mu.Lock()
|
||||||
if nw.cancel != nil {
|
if nw.cancel != nil {
|
||||||
nw.mu.Unlock()
|
nw.mu.Unlock()
|
||||||
|
|||||||
@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
func (m *DefaultManager) setupRefCounters(useNoop bool) {
|
||||||
|
var once sync.Once
|
||||||
|
var wgIface *net.Interface
|
||||||
|
toInterface := func() *net.Interface {
|
||||||
|
once.Do(func() {
|
||||||
|
wgIface = m.wgInterface.ToInterface()
|
||||||
|
})
|
||||||
|
return wgIface
|
||||||
|
}
|
||||||
|
|
||||||
m.routeRefCounter = refcounter.New(
|
m.routeRefCounter = refcounter.New(
|
||||||
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
|
||||||
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface())
|
return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
func(prefix netip.Prefix, _ struct{}) error {
|
func(prefix netip.Prefix, _ struct{}) error {
|
||||||
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface())
|
return m.sysOps.RemoveVPNRoute(prefix, toInterface())
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -337,6 +346,13 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
|
||||||
|
if m.dnsServer != nil {
|
||||||
|
m.dnsServer.BeginBatch()
|
||||||
|
defer m.dnsServer.EndBatch()
|
||||||
|
}
|
||||||
|
|
||||||
for id, handler := range toRemove {
|
for id, handler := range toRemove {
|
||||||
if err := handler.RemoveRoute(); err != nil {
|
if err := handler.RemoveRoute(); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
|
||||||
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
return false, errors.New("not supported on mobile platforms")
|
return false, errors.New("not supported on mobile platforms")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
log.Debugf("Interface monitor: skipped in netstack mode")
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
if ifaceName == "" {
|
if ifaceName == "" {
|
||||||
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
log.Debugf("Interface monitor: empty interface name, skipping monitor")
|
||||||
return false, errors.New("empty interface name")
|
return false, errors.New("empty interface name")
|
||||||
|
|||||||
Reference in New Issue
Block a user