From 101c813e9846bc30ce6c00766e6befdaf7dc965c Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 30 Jan 2026 17:42:14 +0800 Subject: [PATCH] [client] Add macOS default resolvers as fallback (#5201) --- client/internal/dns/host_darwin.go | 25 +++- client/internal/dns/host_darwin_test.go | 166 ++++++++++++++++++++++++ client/internal/dns/server.go | 3 +- client/internal/dns/test/mock.go | 6 + 4 files changed, 196 insertions(+), 4 deletions(-) diff --git a/client/internal/dns/host_darwin.go b/client/internal/dns/host_darwin.go index 71badf0d4..af84c8a85 100644 --- a/client/internal/dns/host_darwin.go +++ b/client/internal/dns/host_darwin.go @@ -9,8 +9,10 @@ import ( "io" "net/netip" "os/exec" + "slices" "strconv" "strings" + "sync" log "github.com/sirupsen/logrus" "golang.org/x/exp/maps" @@ -38,6 +40,9 @@ const ( type systemConfigurator struct { createdKeys map[string]struct{} systemDNSSettings SystemDNSSettings + + mu sync.RWMutex + origNameservers []netip.Addr } func newHostManager() (*systemConfigurator, error) { @@ -218,6 +223,7 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { } var dnsSettings SystemDNSSettings + var serverAddresses []netip.Addr inSearchDomainsArray := false inServerAddressesArray := false @@ -244,9 +250,12 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { dnsSettings.Domains = append(dnsSettings.Domains, searchDomain) } else if inServerAddressesArray { address := strings.Split(line, " : ")[1] - if ip, err := netip.ParseAddr(address); err == nil && ip.Is4() { - dnsSettings.ServerIP = ip.Unmap() - inServerAddressesArray = false // Stop reading after finding the first IPv4 address + if ip, err := netip.ParseAddr(address); err == nil && !ip.IsUnspecified() { + ip = ip.Unmap() + serverAddresses = append(serverAddresses, ip) + if !dnsSettings.ServerIP.IsValid() && ip.Is4() { + dnsSettings.ServerIP = ip + } } } } @@ -258,9 +267,19 @@ func (s *systemConfigurator) getSystemDNSSettings() (SystemDNSSettings, error) { // default to 53 port dnsSettings.ServerPort = DefaultPort + s.mu.Lock() + s.origNameservers = serverAddresses + s.mu.Unlock() + return dnsSettings, nil } +func (s *systemConfigurator) getOriginalNameservers() []netip.Addr { + s.mu.RLock() + defer s.mu.RUnlock() + return slices.Clone(s.origNameservers) +} + func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error { err := s.addDNSState(key, domains, ip, port, true) if err != nil { diff --git a/client/internal/dns/host_darwin_test.go b/client/internal/dns/host_darwin_test.go index c4efd17b0..28915de65 100644 --- a/client/internal/dns/host_darwin_test.go +++ b/client/internal/dns/host_darwin_test.go @@ -109,3 +109,169 @@ func removeTestDNSKey(key string) error { _, err := cmd.CombinedOutput() return err } + +func TestGetOriginalNameservers(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + origNameservers: []netip.Addr{ + netip.MustParseAddr("8.8.8.8"), + netip.MustParseAddr("1.1.1.1"), + }, + } + + servers := configurator.getOriginalNameservers() + assert.Len(t, servers, 2) + assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0]) + assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1]) +} + +func TestGetOriginalNameserversFromSystem(t *testing.T) { + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + + require.NotEmpty(t, servers, "expected at least one DNS server from system configuration") + + for _, server := range servers { + assert.True(t, server.IsValid(), "server address should be valid") + assert.False(t, server.IsUnspecified(), "server address should not be unspecified") + } + + t.Logf("found %d original nameservers: %v", len(servers), servers) +} + +func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) { + t.Helper() + + tmpDir := t.TempDir() + stateFile := filepath.Join(tmpDir, "state.json") + sm := statemanager.New(stateFile) + sm.RegisterState(&ShutdownState{}) + sm.Start() + + configurator := &systemConfigurator{ + createdKeys: make(map[string]struct{}), + } + + searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) + matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) + localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) + + cleanup := func() { + _ = sm.Stop(context.Background()) + for _, key := range []string{searchKey, matchKey, localKey} { + _ = removeTestDNSKey(key) + } + } + + return configurator, sm, cleanup +} + +func TestOriginalNameserversNoTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + routeAll bool + }{ + {"routeall_false", false}, + {"routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + for _, srv := range initialServers { + require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP") + } + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.routeAll, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + for i := 1; i <= 2; i++ { + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + + servers := configurator.getOriginalNameservers() + t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers) + assert.Equal(t, initialServers, servers) + } + }) + } +} + +func TestOriginalNameserversRouteAllTransition(t *testing.T) { + netbirdIP := netip.MustParseAddr("100.64.0.1") + + testCases := []struct { + name string + initialRoute bool + }{ + {"start_with_routeall_false", false}, + {"start_with_routeall_true", true}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + configurator, sm, cleanup := setupTestConfigurator(t) + defer cleanup() + + _, err := configurator.getSystemDNSSettings() + require.NoError(t, err) + initialServers := configurator.getOriginalNameservers() + t.Logf("Initial servers: %v", initialServers) + require.NotEmpty(t, initialServers) + + config := HostDNSConfig{ + ServerIP: netbirdIP, + ServerPort: 53, + RouteAll: tc.initialRoute, + Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, + } + + // First apply + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers := configurator.getOriginalNameservers() + t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers) + assert.Equal(t, initialServers, servers) + + // Toggle RouteAll + config.RouteAll = !tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + // Toggle back + config.RouteAll = tc.initialRoute + err = configurator.applyDNSConfig(config, sm) + require.NoError(t, err) + servers = configurator.getOriginalNameservers() + t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers) + assert.Equal(t, initialServers, servers) + + for _, srv := range servers { + assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP") + } + }) + } +} diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index 1ce7bf1c6..4d4fcc06e 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -615,7 +615,7 @@ func (s *DefaultServer) applyHostConfig() { s.registerFallback(config) } -// registerFallback registers original nameservers as low-priority fallback handlers +// registerFallback registers original nameservers as low-priority fallback handlers. func (s *DefaultServer) registerFallback(config HostDNSConfig) { hostMgrWithNS, ok := s.hostManager.(hostManagerWithOriginalNS) if !ok { @@ -624,6 +624,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) { originalNameservers := hostMgrWithNS.getOriginalNameservers() if len(originalNameservers) == 0 { + s.deregisterHandler([]string{nbdns.RootZone}, PriorityFallback) return } diff --git a/client/internal/dns/test/mock.go b/client/internal/dns/test/mock.go index 1db452805..8d16689bf 100644 --- a/client/internal/dns/test/mock.go +++ b/client/internal/dns/test/mock.go @@ -8,15 +8,21 @@ import ( type MockResponseWriter struct { WriteMsgFunc func(m *dns.Msg) error + lastResponse *dns.Msg } func (rw *MockResponseWriter) WriteMsg(m *dns.Msg) error { + rw.lastResponse = m if rw.WriteMsgFunc != nil { return rw.WriteMsgFunc(m) } return nil } +func (rw *MockResponseWriter) GetLastResponse() *dns.Msg { + return rw.lastResponse +} + func (rw *MockResponseWriter) LocalAddr() net.Addr { return nil } func (rw *MockResponseWriter) RemoteAddr() net.Addr { return nil } func (rw *MockResponseWriter) Write([]byte) (int, error) { return 0, nil }