diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index 01b7edc48..9b7a7b52b 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -42,6 +42,8 @@ const ( dnsPolicyConfigConfigOptionsKey = "ConfigOptions" dnsPolicyConfigConfigOptionsValue = 0x8 + nrptMaxDomainsPerRule = 50 + interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces` interfaceConfigNameServerKey = "NameServer" interfaceConfigSearchListKey = "SearchList" @@ -198,10 +200,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager if len(matchDomains) != 0 { count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP) + // Update count even on error to ensure cleanup covers partially created rules + r.nrptEntryCount = count if err != nil { return fmt.Errorf("add dns match policy: %w", err) } - r.nrptEntryCount = count } else { r.nrptEntryCount = 0 } @@ -239,23 +242,33 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error { func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) { // if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored // see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745 - for i, domain := range domains { - localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i) - gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i) - singleDomain := []string{domain} + // We need to batch domains into chunks and create one NRPT rule per batch. + ruleIndex := 0 + for i := 0; i < len(domains); i += nrptMaxDomainsPerRule { + end := i + nrptMaxDomainsPerRule + if end > len(domains) { + end = len(domains) + } + batchDomains := domains[i:end] - if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err) + localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex) + gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex) + + if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil { + return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err) } + // Increment immediately so the caller's cleanup path knows about this rule + ruleIndex++ + if r.gpo { - if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil { - return i, fmt.Errorf("configure gpo DNS policy: %w", err) + if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil { + return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err) } } - log.Debugf("added NRPT entry for domain: %s", domain) + log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains)) } if r.gpo { @@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr } } - log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains) - return len(domains), nil + log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains) + return ruleIndex, nil } func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error { diff --git a/client/internal/dns/host_windows_test.go b/client/internal/dns/host_windows_test.go index 19496bf5a..3cd2b1bd5 100644 --- a/client/internal/dns/host_windows_test.go +++ b/client/internal/dns/host_windows_test.go @@ -12,6 +12,7 @@ import ( // TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up // when the number of match domains decreases between configuration changes. +// With batching enabled (50 domains per rule), we need enough domains to create multiple rules. func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { if testing.Short() { t.Skip("skipping registry integration test in short mode") @@ -37,51 +38,60 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) { gpo: false, } - config5 := HostDNSConfig{ - ServerIP: testIP, - Domains: []DomainConfig{ - {Domain: "domain1.com", MatchOnly: true}, - {Domain: "domain2.com", MatchOnly: true}, - {Domain: "domain3.com", MatchOnly: true}, - {Domain: "domain4.com", MatchOnly: true}, - {Domain: "domain5.com", MatchOnly: true}, - }, + // Create 125 domains which will result in 3 NRPT rules (50+50+25) + domains125 := make([]DomainConfig, 125) + for i := 0; i < 125; i++ { + domains125[i] = DomainConfig{ + Domain: fmt.Sprintf("domain%d.com", i+1), + MatchOnly: true, + } } - err = cfg.applyDNSConfig(config5, nil) + config125 := HostDNSConfig{ + ServerIP: testIP, + Domains: domains125, + } + + err = cfg.applyDNSConfig(config125, nil) require.NoError(t, err) - // Verify all 5 entries exist - for i := 0; i < 5; i++ { + // Verify 3 NRPT rules exist + assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains") + for i := 0; i < 3; i++ { exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) require.NoError(t, err) - assert.True(t, exists, "Entry %d should exist after first config", i) + assert.True(t, exists, "NRPT rule %d should exist after first config", i) } - config2 := HostDNSConfig{ + // Reduce to 75 domains which will result in 2 NRPT rules (50+25) + domains75 := make([]DomainConfig, 75) + for i := 0; i < 75; i++ { + domains75[i] = DomainConfig{ + Domain: fmt.Sprintf("domain%d.com", i+1), + MatchOnly: true, + } + } + + config75 := HostDNSConfig{ ServerIP: testIP, - Domains: []DomainConfig{ - {Domain: "domain1.com", MatchOnly: true}, - {Domain: "domain2.com", MatchOnly: true}, - }, + Domains: domains75, } - err = cfg.applyDNSConfig(config2, nil) + err = cfg.applyDNSConfig(config75, nil) require.NoError(t, err) - // Verify first 2 entries exist + // Verify first 2 NRPT rules exist + assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains") for i := 0; i < 2; i++ { exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) require.NoError(t, err) - assert.True(t, exists, "Entry %d should exist after second config", i) + assert.True(t, exists, "NRPT rule %d should exist after second config", i) } - // Verify entries 2-4 are cleaned up - for i := 2; i < 5; i++ { - exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) - require.NoError(t, err) - assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i) - } + // Verify rule 2 is cleaned up + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2)) + require.NoError(t, err) + assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains") } func registryKeyExists(path string) (bool, error) { @@ -97,6 +107,106 @@ func registryKeyExists(path string) (bool, error) { } func cleanupRegistryKeys(*testing.T) { - cfg := ®istryConfigurator{nrptEntryCount: 10} + // Clean up more entries to account for batching tests with many domains + cfg := ®istryConfigurator{nrptEntryCount: 20} _ = cfg.removeDNSMatchPolicies() } + +// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules. +func TestNRPTDomainBatching(t *testing.T) { + if testing.Short() { + t.Skip("skipping registry integration test in short mode") + } + + defer cleanupRegistryKeys(t) + cleanupRegistryKeys(t) + + testIP := netip.MustParseAddr("100.64.0.1") + + // Create a test interface registry key so updateSearchDomains doesn't fail + testGUID := "{12345678-1234-1234-1234-123456789ABC}" + interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID + testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE) + require.NoError(t, err, "Should create test interface registry key") + testKey.Close() + defer func() { + _ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath) + }() + + cfg := ®istryConfigurator{ + guid: testGUID, + gpo: false, + } + + testCases := []struct { + name string + domainCount int + expectedRuleCount int + }{ + { + name: "Less than 50 domains (single rule)", + domainCount: 30, + expectedRuleCount: 1, + }, + { + name: "Exactly 50 domains (single rule)", + domainCount: 50, + expectedRuleCount: 1, + }, + { + name: "51 domains (two rules)", + domainCount: 51, + expectedRuleCount: 2, + }, + { + name: "100 domains (two rules)", + domainCount: 100, + expectedRuleCount: 2, + }, + { + name: "125 domains (three rules: 50+50+25)", + domainCount: 125, + expectedRuleCount: 3, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Clean up before each subtest + cleanupRegistryKeys(t) + + // Generate domains + domains := make([]DomainConfig, tc.domainCount) + for i := 0; i < tc.domainCount; i++ { + domains[i] = DomainConfig{ + Domain: fmt.Sprintf("domain%d.com", i+1), + MatchOnly: true, + } + } + + config := HostDNSConfig{ + ServerIP: testIP, + Domains: domains, + } + + err := cfg.applyDNSConfig(config, nil) + require.NoError(t, err) + + // Verify that exactly expectedRuleCount rules were created + assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount, + "Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount) + + // Verify all expected rules exist + for i := 0; i < tc.expectedRuleCount; i++ { + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)) + require.NoError(t, err) + assert.True(t, exists, "NRPT rule %d should exist", i) + } + + // Verify no extra rules were created + exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount)) + require.NoError(t, err) + assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount) + }) + } +} diff --git a/client/internal/dns/mock_server.go b/client/internal/dns/mock_server.go index 0f89b9016..fe160e20a 100644 --- a/client/internal/dns/mock_server.go +++ b/client/internal/dns/mock_server.go @@ -84,3 +84,18 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error { func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error { return nil } + +// BeginBatch mock implementation of BeginBatch from Server interface +func (m *MockServer) BeginBatch() { + // Mock implementation - no-op +} + +// EndBatch mock implementation of EndBatch from Server interface +func (m *MockServer) EndBatch() { + // Mock implementation - no-op +} + +// CancelBatch mock implementation of CancelBatch from Server interface +func (m *MockServer) CancelBatch() { + // Mock implementation - no-op +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index c2b01de62..179517bbd 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -45,6 +45,9 @@ type IosDnsManager interface { type Server interface { RegisterHandler(domains domain.List, handler dns.Handler, priority int) DeregisterHandler(domains domain.List, priority int) + BeginBatch() + EndBatch() + CancelBatch() Initialize() error Stop() DnsIP() netip.Addr @@ -87,6 +90,7 @@ type DefaultServer struct { currentConfigHash uint64 handlerChain *HandlerChain extraDomains map[domain.Domain]int + batchMode bool mgmtCacheResolver *mgmt.Resolver @@ -234,7 +238,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler // convert to zone with simple ref counter s.extraDomains[toZone(domain)]++ } - s.applyHostConfig() + if !s.batchMode { + s.applyHostConfig() + } } func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) { @@ -263,9 +269,41 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) { delete(s.extraDomains, zone) } } + if !s.batchMode { + s.applyHostConfig() + } +} + +// BeginBatch starts batch mode for DNS handler registration/deregistration. +// In batch mode, applyHostConfig() is not called after each handler operation, +// allowing multiple handlers to be registered/deregistered efficiently. +// Must be followed by EndBatch() to apply the accumulated changes. +func (s *DefaultServer) BeginBatch() { + s.mux.Lock() + defer s.mux.Unlock() + log.Debugf("DNS batch mode enabled") + s.batchMode = true +} + +// EndBatch ends batch mode and applies all accumulated DNS configuration changes. +func (s *DefaultServer) EndBatch() { + s.mux.Lock() + defer s.mux.Unlock() + log.Debugf("DNS batch mode disabled, applying accumulated changes") + s.batchMode = false s.applyHostConfig() } +// CancelBatch cancels batch mode without applying accumulated changes. +// This is useful when operations fail partway through and you want to +// discard partial state rather than applying it. +func (s *DefaultServer) CancelBatch() { + s.mux.Lock() + defer s.mux.Unlock() + log.Debugf("DNS batch mode cancelled, discarding accumulated changes") + s.batchMode = false +} + func (s *DefaultServer) deregisterHandler(domains []string, priority int) { log.Debugf("deregistering handler with priority %d for %v", priority, domains) @@ -523,6 +561,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error { s.currentConfig.RouteAll = false } + // Always apply host config for management updates, regardless of batch mode s.applyHostConfig() s.shutdownWg.Add(1) @@ -887,6 +926,7 @@ func (s *DefaultServer) upstreamCallbacks( } } + // Always apply host config when nameserver goes down, regardless of batch mode s.applyHostConfig() go func() { @@ -922,6 +962,7 @@ func (s *DefaultServer) upstreamCallbacks( s.registerHandler([]string{nbdns.RootZone}, handler, priority) } + // Always apply host config when nameserver reactivates, regardless of batch mode s.applyHostConfig() s.updateNSState(nsGroup, nil, true) diff --git a/client/internal/dns/server_export_test.go b/client/internal/dns/server_export_test.go index 1fa343b52..25d08d698 100644 --- a/client/internal/dns/server_export_test.go +++ b/client/internal/dns/server_export_test.go @@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) { t.Errorf("invalid dns server instance: %s", err) } - if srvB != srv { + mockSrvB, ok := srvB.(*MockServer) + if !ok { + t.Errorf("returned server is not a MockServer") + } + + if mockSrvB != srv { t.Errorf("mismatch dns instances") } } diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 077b9521b..9afe2049d 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -346,6 +346,23 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { } var merr *multierror.Error + + // Begin batch mode to avoid calling applyHostConfig() after each DNS handler operation + batchStarted := false + if m.dnsServer != nil { + m.dnsServer.BeginBatch() + batchStarted = true + defer func() { + if merr != nil { + // On error, cancel batch to discard partial DNS state + m.dnsServer.CancelBatch() + } else { + // On success, apply accumulated DNS changes + m.dnsServer.EndBatch() + } + }() + } + for id, handler := range toRemove { if err := handler.RemoveRoute(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove route %s: %w", handler.String(), err)) @@ -376,6 +393,7 @@ func (m *DefaultManager) updateSystemRoutes(newRoutes route.HAMap) error { m.activeRoutes[id] = handler } + _ = batchStarted // Mark as used return nberrors.FormatErrorOrNil(merr) }