Compare commits

...

5 Commits

Author SHA1 Message Date
Zoltán Papp
85e4afe100 Refactor route manager to use sync.Once for wgInterface access
- Added `toInterface` helper to ensure `wgInterface.ToInterface()` is called once.
- Updated route addition/removal methods to use `toInterface` for efficiency.
2026-02-03 15:26:37 +01:00
Zoltán Papp
a72e0e67e5 Optimize Windows DNS performance with domain batching and batch mode
Implement two-layer optimization to reduce Windows NRPT registry operations:

1. Domain Batching (host_windows.go):
  - Batch up to 50 domains per NRPT rule (Windows undocumented limit)
  - Reduces NRPT rules by ~97% (e.g., 184 domains: 184 rules → 4 rules)
  - Modified addDNSMatchPolicy() to create batched NRPT entries
  - Added comprehensive tests in host_windows_test.go

2. Batch Mode (server.go):
  - Added BeginBatch/EndBatch methods to defer DNS updates
  - Modified RegisterHandler/DeregisterHandler to skip applyHostConfig in batch mode
  - Protected all applyHostConfig() calls with batch mode checks
  - Updated route manager to wrap route operations with batch calls
2026-02-03 00:15:04 +01:00
Zoltán Papp
a33f60e3fd Adds timing measurement to handleSync to help diagnose sync performance issues 2026-02-01 00:12:11 +01:00
Viktor Liu
0c990ab662 [client] Add block inbound option to the embed client (#5215) 2026-01-30 10:42:39 +01:00
Viktor Liu
101c813e98 [client] Add macOS default resolvers as fallback (#5201) 2026-01-30 10:42:14 +01:00
13 changed files with 413 additions and 23 deletions

View File

@@ -69,6 +69,8 @@ type Options struct {
StatePath string StatePath string
// DisableClientRoutes disables the client routes // DisableClientRoutes disables the client routes
DisableClientRoutes bool DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
} }
// validateCredentials checks that exactly one credential type is provided // validateCredentials checks that exactly one credential type is provided
@@ -137,6 +139,7 @@ func New(opts Options) (*Client, error) {
PreSharedKey: &opts.PreSharedKey, PreSharedKey: &opts.PreSharedKey,
DisableServerRoutes: &t, DisableServerRoutes: &t,
DisableClientRoutes: &opts.DisableClientRoutes, DisableClientRoutes: &opts.DisableClientRoutes,
BlockInbound: &opts.BlockInbound,
} }
if opts.ConfigPath != "" { if opts.ConfigPath != "" {
config, err = profilemanager.UpdateOrCreateConfig(input) config, err = profilemanager.UpdateOrCreateConfig(input)

View File

@@ -9,8 +9,10 @@ import (
"io" "io"
"net/netip" "net/netip"
"os/exec" "os/exec"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
@@ -38,6 +40,9 @@ const (
type systemConfigurator struct { type systemConfigurator struct {
createdKeys map[string]struct{} createdKeys map[string]struct{}
systemDNSSettings SystemDNSSettings systemDNSSettings SystemDNSSettings
mu sync.RWMutex
origNameservers []netip.Addr
} }
func newHostManager() (*systemConfigurator, error) { func newHostManager() (*systemConfigurator, error) {
@@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
} }
var dnsSettings SystemDNSSettings var dnsSettings SystemDNSSettings
var serverAddresses []netip.Addr
inSearchDomainsArray := false inSearchDomainsArray := false
inServerAddressesArray := false inServerAddressesArray := false
@@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) dnsSettings.Domains = append(dnsSettings.Domains, searchDomain)
} else if inServerAddressesArray { } else if inServerAddressesArray {
address := strings.Split(line, " : ")[1] address := strings.Split(line, " : ")[1]
if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() {
dnsSettings.ServerIP = ip.Unmap() ip = ip.Unmap()
inServerAddressesArray = false // Stop reading after finding the first IPv4 address serverAddresses = append(serverAddresses, ip)
if !dnsSettings.ServerIP.IsValid() && ip.Is4() {
dnsSettings.ServerIP = ip
}
} }
} }
} }
@@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) {
// default to 53 port // default to 53 port
dnsSettings.ServerPort = DefaultPort dnsSettings.ServerPort = DefaultPort
s.mu.Lock()
s.origNameservers = serverAddresses
s.mu.Unlock()
return dnsSettings, nil return dnsSettings, nil
} }
func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
s.mu.RLock()
defer s.mu.RUnlock()
return slices.Clone(s.origNameservers)
}
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
err := s.addDNSState(key, domains, ip, port, true) err := s.addDNSState(key, domains, ip, port, true)
if err != nil { if err != nil {

View File

@@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error {
_, err := cmd.CombinedOutput() _, err := cmd.CombinedOutput()
return err return err
} }
func TestGetOriginalNameservers(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
origNameservers: []netip.Addr{
netip.MustParseAddr("8.8.8.8"),
netip.MustParseAddr("1.1.1.1"),
},
}
servers := configurator.getOriginalNameservers()
assert.Len(t, servers, 2)
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
}
func TestGetOriginalNameserversFromSystem(t *testing.T) {
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
for _, server := range servers {
assert.True(t, server.IsValid(), "server address should be valid")
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
}
t.Logf("found %d original nameservers: %v", len(servers), servers)
}
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
t.Helper()
tmpDir := t.TempDir()
stateFile := filepath.Join(tmpDir, "state.json")
sm := statemanager.New(stateFile)
sm.RegisterState(&ShutdownState{})
sm.Start()
configurator := &systemConfigurator{
createdKeys: make(map[string]struct{}),
}
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
cleanup := func() {
_ = sm.Stop(context.Background())
for _, key := range []string{searchKey, matchKey, localKey} {
_ = removeTestDNSKey(key)
}
}
return configurator, sm, cleanup
}
func TestOriginalNameserversNoTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
routeAll bool
}{
{"routeall_false", false},
{"routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
for _, srv := range initialServers {
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
}
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.routeAll,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
for i := 1; i <= 2; i++ {
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
assert.Equal(t, initialServers, servers)
}
})
}
}
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
netbirdIP := netip.MustParseAddr("100.64.0.1")
testCases := []struct {
name string
initialRoute bool
}{
{"start_with_routeall_false", false},
{"start_with_routeall_true", true},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
configurator, sm, cleanup := setupTestConfigurator(t)
defer cleanup()
_, err := configurator.getSystemDNSSettings()
require.NoError(t, err)
initialServers := configurator.getOriginalNameservers()
t.Logf("Initial servers: %v", initialServers)
require.NotEmpty(t, initialServers)
config := HostDNSConfig{
ServerIP: netbirdIP,
ServerPort: 53,
RouteAll: tc.initialRoute,
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
}
// First apply
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers := configurator.getOriginalNameservers()
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
assert.Equal(t, initialServers, servers)
// Toggle RouteAll
config.RouteAll = !tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
// Toggle back
config.RouteAll = tc.initialRoute
err = configurator.applyDNSConfig(config, sm)
require.NoError(t, err)
servers = configurator.getOriginalNameservers()
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
assert.Equal(t, initialServers, servers)
for _, srv := range servers {
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
}
})
}
}

View File

@@ -42,6 +42,10 @@ const (
dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
dnsPolicyConfigConfigOptionsValue = 0x8 dnsPolicyConfigConfigOptionsValue = 0x8
// NRPT rules cannot handle more than 50 domains per rule.
// This is an undocumented Windows limitation.
nrptMaxDomainsPerRule = 50
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServerKey = "NameServer" interfaceConfigNameServerKey = "NameServer"
interfaceConfigSearchListKey = "SearchList" interfaceConfigSearchListKey = "SearchList"
@@ -239,23 +243,32 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
for i, domain := range domains {
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
singleDomain := []string{domain} // NRPT rules have an undocumented restriction: each rule can only handle up to 50 domains.
// We need to batch domains into chunks and create one NRPT rule per batch.
ruleIndex := 0
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
end := i + nrptMaxDomainsPerRule
if end > len(domains) {
end = len(domains)
}
batchDomains := domains[i:end]
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
} }
if r.gpo { if r.gpo {
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
return i, fmt.Errorf("configure gpo DNS policy: %w", err) return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex, err)
} }
} }
log.Debugf("added NRPT entry for domain: %s", domain) log.Debugf("added NRPT rule %d with %d domains", ruleIndex, len(batchDomains))
ruleIndex++
} }
if r.gpo { if r.gpo {
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
} }
} }
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains) log.Infof("added %d NRPT rules for %d domains. Domain list: %s", ruleIndex, len(domains), domains)
return len(domains), nil return ruleIndex, nil
} }
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {

View File

@@ -97,6 +97,107 @@ func registryKeyExists(path string) (bool, error) {
} }
func cleanupRegistryKeys(*testing.T) { func cleanupRegistryKeys(*testing.T) {
cfg := &registryConfigurator{nrptEntryCount: 10} // Clean up more entries to account for batching tests with many domains
cfg := &registryConfigurator{nrptEntryCount: 20}
_ = cfg.removeDNSMatchPolicies() _ = cfg.removeDNSMatchPolicies()
} }
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules
// with a maximum of 50 domains per rule (Windows limitation).
func TestNRPTDomainBatching(t *testing.T) {
if testing.Short() {
t.Skip("skipping registry integration test in short mode")
}
defer cleanupRegistryKeys(t)
cleanupRegistryKeys(t)
testIP := netip.MustParseAddr("100.64.0.1")
// Create a test interface registry key so updateSearchDomains doesn't fail
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
require.NoError(t, err, "Should create test interface registry key")
testKey.Close()
defer func() {
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
}()
cfg := &registryConfigurator{
guid: testGUID,
gpo: false,
}
testCases := []struct {
name string
domainCount int
expectedRuleCount int
}{
{
name: "Less than 50 domains (single rule)",
domainCount: 30,
expectedRuleCount: 1,
},
{
name: "Exactly 50 domains (single rule)",
domainCount: 50,
expectedRuleCount: 1,
},
{
name: "51 domains (two rules)",
domainCount: 51,
expectedRuleCount: 2,
},
{
name: "100 domains (two rules)",
domainCount: 100,
expectedRuleCount: 2,
},
{
name: "125 domains (three rules: 50+50+25)",
domainCount: 125,
expectedRuleCount: 3,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Clean up before each subtest
cleanupRegistryKeys(t)
// Generate domains
domains := make([]DomainConfig, tc.domainCount)
for i := 0; i < tc.domainCount; i++ {
domains[i] = DomainConfig{
Domain: fmt.Sprintf("domain%d.com", i+1),
MatchOnly: true,
}
}
config := HostDNSConfig{
ServerIP: testIP,
Domains: domains,
}
err := cfg.applyDNSConfig(config, nil)
require.NoError(t, err)
// Verify that exactly expectedRuleCount rules were created
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
// Verify all expected rules exist
for i := 0; i < tc.expectedRuleCount; i++ {
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
require.NoError(t, err)
assert.True(t, exists, "NRPT rule %d should exist", i)
}
// Verify no extra rules were created
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
require.NoError(t, err)
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
})
}
}

View File

@@ -84,3 +84,13 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
return nil return nil
} }
// BeginBatch mock implementation of BeginBatch from Server interface
func (m *MockServer) BeginBatch() {
// Mock implementation - no-op
}
// EndBatch mock implementation of EndBatch from Server interface
func (m *MockServer) EndBatch() {
// Mock implementation - no-op
}

View File

@@ -41,6 +41,8 @@ type IosDnsManager interface {
type Server interface { type Server interface {
RegisterHandler(domains domain.List, handler dns.Handler, priority int) RegisterHandler(domains domain.List, handler dns.Handler, priority int)
DeregisterHandler(domains domain.List, priority int) DeregisterHandler(domains domain.List, priority int)
BeginBatch()
EndBatch()
Initialize() error Initialize() error
Stop() Stop()
DnsIP() netip.Addr DnsIP() netip.Addr
@@ -83,6 +85,7 @@ type DefaultServer struct {
currentConfigHash uint64 currentConfigHash uint64
handlerChain *HandlerChain handlerChain *HandlerChain
extraDomains map[domain.Domain]int extraDomains map[domain.Domain]int
batchMode bool
mgmtCacheResolver *mgmt.Resolver mgmtCacheResolver *mgmt.Resolver
@@ -230,7 +233,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
// convert to zone with simple ref counter // convert to zone with simple ref counter
s.extraDomains[toZone(domain)]++ s.extraDomains[toZone(domain)]++
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
} }
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
@@ -259,6 +264,28 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
delete(s.extraDomains, zone) delete(s.extraDomains, zone)
} }
} }
if !s.batchMode {
s.applyHostConfig()
}
}
// BeginBatch starts batch mode for DNS handler registration/deregistration.
// In batch mode, applyHostConfig() is not called after each handler operation,
// allowing multiple handlers to be registered/deregistered efficiently.
// Must be followed by EndBatch() to apply the accumulated changes.
func (s *DefaultServer) BeginBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode enabled")
s.batchMode = true
}
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
func (s *DefaultServer) EndBatch() {
s.mux.Lock()
defer s.mux.Unlock()
log.Infof("DNS batch mode disabled, applying accumulated changes")
s.batchMode = false
s.applyHostConfig() s.applyHostConfig()
} }
@@ -508,7 +535,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.currentConfig.RouteAll = false s.currentConfig.RouteAll = false
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
s.shutdownWg.Add(1) s.shutdownWg.Add(1)
go func() { go func() {
@@ -615,7 +644,7 @@ func (s *DefaultServer) applyHostConfig() {
s.registerFallback(config) s.registerFallback(config)
} }
// registerFallback registers original nameservers as low-priority fallback handlers // registerFallback registers original nameservers as low-priority fallback handlers.
func (s *DefaultServer) registerFallback(config HostDNSConfig) { func (s *DefaultServer) registerFallback(config HostDNSConfig) {
hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS)
if !ok { if !ok {
@@ -624,6 +653,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
originalNameservers := hostMgrWithNS.getOriginalNameservers() originalNameservers := hostMgrWithNS.getOriginalNameservers()
if len(originalNameservers) == 0 { if len(originalNameservers) == 0 {
s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback)
return return
} }
@@ -871,7 +901,9 @@ func (s *DefaultServer) upstreamCallbacks(
} }
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
go func() { go func() {
if err := s.stateManager.PersistState(s.ctx); err != nil { if err := s.stateManager.PersistState(s.ctx); err != nil {
@@ -906,7 +938,9 @@ func (s *DefaultServer) upstreamCallbacks(
s.registerHandler([]string{nbdns.RootZone}, handler, priority) s.registerHandler([]string{nbdns.RootZone}, handler, priority)
} }
s.applyHostConfig() if !s.batchMode {
s.applyHostConfig()
}
s.updateNSState(nsGroup, nil, true) s.updateNSState(nsGroup, nil, true)
} }

View File

@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
t.Errorf("invalid dns server instance: %s", err) t.Errorf("invalid dns server instance: %s", err)
} }
if srvB != srv { mockSrvB, ok := srvB.(*MockServer)
if !ok {
t.Errorf("returned server is not a MockServer")
}
if mockSrvB != srv {
t.Errorf("mismatch dns instances") t.Errorf("mismatch dns instances")
} }
} }

View File

@@ -8,15 +8,21 @@ import (
type MockResponseWriter struct { type MockResponseWriter struct {
WriteMsgFunc func(m *dns.Msg) error WriteMsgFunc func(m *dns.Msg) error
lastResponse *dns.Msg
} }
func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error {
rw.lastResponse = m
if rw.WriteMsgFunc != nil { if rw.WriteMsgFunc != nil {
return rw.WriteMsgFunc(m) return rw.WriteMsgFunc(m)
} }
return nil return nil
} }
func (rw *MockResponseWriter) GetLastResponse() *dns.Msg {
return rw.lastResponse
}
func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil }
func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil }
func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }

View File

@@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate
} }
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
started := time.Now()
defer func() {
log.Infof("sync finished in %s", time.Since(started))
}()
e.syncMsgMux.Lock() e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock() defer e.syncMsgMux.Unlock()

View File

@@ -14,6 +14,7 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
) )
@@ -37,6 +38,11 @@ func New() *NetworkMonitor {
// Listen begins monitoring network changes. When a change is detected, this function will return without error. // Listen begins monitoring network changes. When a change is detected, this function will return without error.
func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) { func (nw *NetworkMonitor) Listen(ctx context.Context) (err error) {
if netstack.IsEnabled() {
log.Debugf("Network monitor: skipping in netstack mode")
return nil
}
nw.mu.Lock() nw.mu.Lock()
if nw.cancel != nil { if nw.cancel != nil {
nw.mu.Unlock() nw.mu.Unlock()

View File

@@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) {
} }
func (m *DefaultManager) setupRefCounters(useNoop bool) { func (m *DefaultManager) setupRefCounters(useNoop bool) {
var once sync.Once
var wgIface *net.Interface
toInterface := func() *net.Interface {
once.Do(func() {
wgIface = m.wgInterface.ToInterface()
})
return wgIface
}
m.routeRefCounter = refcounter.New( m.routeRefCounter = refcounter.New(
func(prefix netip.Prefix, _ struct{}) (struct{}, error) { func(prefix netip.Prefix, _ struct{}) (struct{}, error) {
return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface())
}, },
func(prefix netip.Prefix, _ struct{}) error { func(prefix netip.Prefix, _ struct{}) error {
return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) return m.sysOps.RemoveVPNRoute(prefix, toInterface())
}, },
) )
@@ -337,6 +346,13 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error {
} }
var merr *multierror.Error var merr *multierror.Error
// Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation
if m.dnsServer != nil {
m.dnsServer.BeginBatch()
defer m.dnsServer.EndBatch()
}
for id, handler := range toRemove { for id, handler := range toRemove {
if err := handler.RemoveRoute(); err != nil { if err := handler.RemoveRoute(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err)) merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface/netstack"
) )
// WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine // WGIfaceMonitor monitors the WireGuard interface lifecycle and restarts the engine
@@ -35,6 +37,11 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
return false, errors.New("not supported on mobile platforms") return false, errors.New("not supported on mobile platforms")
} }
if netstack.IsEnabled() {
log.Debugf("Interface monitor: skipped in netstack mode")
return false, nil
}
if ifaceName == "" { if ifaceName == "" {
log.Debugf("Interface monitor: empty interface name, skipping monitor") log.Debugf("Interface monitor: empty interface name, skipping monitor")
return false, errors.New("empty interface name") return false, errors.New("empty interface name")