diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 01b7edc48..48ef786dd 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -42,6 +42,10 @@ const ( dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsValue = 0x8 + // NRPT rules cannot handle more than 50 domains per rule. + // This is an undocumented Windows limitation. + nrptMaxDomainsPerRule = 50 + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigNameServerKey = "NameServer" interfaceConfigSearchListKey = "SearchList" @@ -239,23 +243,32 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) { // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 - for i, domain := range domains { - localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - singleDomain := []string{domain} + // NRPT rules have an undocumented restriction: each rule can only handle up to 50 domains. + // We need to batch domains into chunks and create one NRPT rule per batch. + ruleIndex := 0 + for i := 0; i < len(domains); i += nrptMaxDomainsPerRule { + end := i + nrptMaxDomainsPerRule + if end > len(domains) { + end = len(domains) + } + batchDomains := domains[i:end] - if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex) + gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex) + + if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil { + return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err) } if r.gpo { - if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure gpo DNS policy: %w", err) + if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil { + return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex, err) } } - log.Debugf("added NRPT entry for domain: %s", domain) + log.Debugf("added NRPT rule %d with %d domains", ruleIndex, len(batchDomains)) + ruleIndex++ } if r.gpo { @@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr } } - log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains) - return len(domains), nil + log.Infof("added %d NRPT rules for %d domains. Domain list: %s", ruleIndex, len(domains), domains) + return ruleIndex, nil } func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go index 19496bf5a..a150ace96 100644 --- a/client/internal/dns/host_windows_test.go +++ b/client/internal/dns/host_windows_test.go @@ -97,6 +97,107 @@ func registryKeyExists(path string) (bool, error) { } func cleanupRegistryKeys(*testing.T) { - cfg := ®istryConfigurator{nrptEntryCount: 10} + // Clean up more entries to account for batching tests with many domains + cfg := ®istryConfigurator{nrptEntryCount: 20} _ = cfg.removeDNSMatchPolicies() } + +// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules +// with a maximum of 50 domains per rule (Windows limitation). +func TestNRPTDomainBatching(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + testCases := []struct { + name string + domainCount int + expectedRuleCount int + }{ + { + name: "Less than 50 domains (single rule)", + domainCount: 30, + expectedRuleCount: 1, + }, + { + name: "Exactly 50 domains (single rule)", + domainCount: 50, + expectedRuleCount: 1, + }, + { + name: "51 domains (two rules)", + domainCount: 51, + expectedRuleCount: 2, + }, + { + name: "100 domains (two rules)", + domainCount: 100, + expectedRuleCount: 2, + }, + { + name: "125 domains (three rules: 50+50+25)", + domainCount: 125, + expectedRuleCount: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Clean up before each subtest + cleanupRegistryKeys(t) + + // Generate domains + domains := make([]DomainConfig, tc.domainCount) + for i := 0; i < tc.domainCount; i++ { + domains[i] = DomainConfig{ + Domain: fmt.Sprintf("domain%d.com", i+1), + MatchOnly: true, + } + } + + config := HostDNSConfig{ + ServerIP: testIP, + Domains: domains, + } + + err := cfg.applyDNSConfig(config, nil) + require.NoError(t, err) + + // Verify that exactly expectedRuleCount rules were created + assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount, + "Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount) + + // Verify all expected rules exist + for i := 0; i < tc.expectedRuleCount; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "NRPT rule %d should exist", i) + } + + // Verify no extra rules were created + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount)) + require.NoError(t, err) + assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount) + }) + } +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0f89b9016..99950af51 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -84,3 +84,13 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } + +// BeginBatch mock implementation of BeginBatch from Server interface +func (m *MockServer) BeginBatch() { + // Mock implementation - no-op +} + +// EndBatch mock implementation of EndBatch from Server interface +func (m *MockServer) EndBatch() { + // Mock implementation - no-op +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 4d4fcc06e..4d1af20d6 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -41,6 +41,8 @@ type IosDnsManager interface { type Server interface { RegisterHandler(domains domain.List, handler dns.Handler, priority int) DeregisterHandler(domains domain.List, priority int) + BeginBatch() + EndBatch() Initialize() error Stop() DnsIP() netip.Addr @@ -83,6 +85,7 @@ type DefaultServer struct { currentConfigHash uint64 handlerChain *HandlerChain extraDomains map[domain.Domain]int + batchMode bool mgmtCacheResolver *mgmt.Resolver @@ -230,7 +233,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } - s.applyHostConfig() + if !s.batchMode { + s.applyHostConfig() + } } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { @@ -259,6 +264,28 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { delete(s.extraDomains, zone) } } + if !s.batchMode { + s.applyHostConfig() + } +} + +// BeginBatch starts batch mode for DNS handler registration/deregistration. +// In batch mode, applyHostConfig() is not called after each handler operation, +// allowing multiple handlers to be registered/deregistered efficiently. +// Must be followed by EndBatch() to apply the accumulated changes. +func (s *DefaultServer) BeginBatch() { + s.mux.Lock() + defer s.mux.Unlock() + log.Infof("DNS batch mode enabled") + s.batchMode = true +} + +// EndBatch ends batch mode and applies all accumulated DNS configuration changes. +func (s *DefaultServer) EndBatch() { + s.mux.Lock() + defer s.mux.Unlock() + log.Infof("DNS batch mode disabled, applying accumulated changes") + s.batchMode = false s.applyHostConfig() } @@ -508,7 +535,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.currentConfig.RouteAll = false } - s.applyHostConfig() + if !s.batchMode { + s.applyHostConfig() + } s.shutdownWg.Add(1) go func() { @@ -872,7 +901,9 @@ func (s *DefaultServer) upstreamCallbacks( } } - s.applyHostConfig() + if !s.batchMode { + s.applyHostConfig() + } go func() { if err := s.stateManager.PersistState(s.ctx); err != nil { @@ -907,7 +938,9 @@ func (s *DefaultServer) upstreamCallbacks( s.registerHandler([]string{nbdns.RootZone}, handler, priority) } - s.applyHostConfig() + if !s.batchMode { + s.applyHostConfig() + } s.updateNSState(nsGroup, nil, true) } diff --git a/client/internal/dns/server_export_test.go b/client/internal/dns/server_export_test.go index 1fa343b52..25d08d698 100644 --- a/client/internal/dns/server_export_test.go +++ b/client/internal/dns/server_export_test.go @@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) { t.Errorf("invalid dns server instance: %s", err) } - if srvB != srv { + mockSrvB, ok := srvB.(*MockServer) + if !ok { + t.Errorf("returned server is not a MockServer") + } + + if mockSrvB != srv { t.Errorf("mismatch dns instances") } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2baa0e668..c565c56c7 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -337,6 +337,13 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { } var merr *multierror.Error + + // Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation + if m.dnsServer != nil { + m.dnsServer.BeginBatch() + defer m.dnsServer.EndBatch() + } + for id, handler := range toRemove { if err := handler.RemoveRoute(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err))