//go:build !ios package dns import ( "context" "net/netip" "os/exec" "path/filepath" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/client/internal/statemanager" ) func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) { if testing.Short() { t.Skip("skipping scutil integration test in short mode") } tmpDir := t.TempDir() stateFile := filepath.Join(tmpDir, "state.json") sm := statemanager.New(stateFile) sm.RegisterState(&ShutdownState{}) sm.Start() defer func() { require.NoError(t, sm.Stop(context.Background())) }() configurator := &systemConfigurator{ createdKeys: make(map[string]struct{}), } config := HostDNSConfig{ ServerIP: netip.MustParseAddr("100.64.0.1"), ServerPort: 53, RouteAll: true, Domains: []DomainConfig{ {Domain: "example.com", MatchOnly: true}, }, } err := configurator.applyDNSConfig(config, sm) require.NoError(t, err) require.NoError(t, sm.PersistState(context.Background())) searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) defer func() { for _, key := range []string{searchKey, matchKey, localKey} { _ = removeTestDNSKey(key) } }() for _, key := range []string{searchKey, matchKey, localKey} { exists, err := checkDNSKeyExists(key) require.NoError(t, err) if exists { t.Logf("Key %s exists before cleanup", key) } } sm2 := statemanager.New(stateFile) sm2.RegisterState(&ShutdownState{}) err = sm2.LoadState(&ShutdownState{}) require.NoError(t, err) state := sm2.GetState(&ShutdownState{}) if state == nil { t.Skip("State not saved, skipping cleanup test") } shutdownState, ok := state.(*ShutdownState) require.True(t, ok) err = shutdownState.Cleanup() require.NoError(t, err) for _, key := range []string{searchKey, matchKey, localKey} { exists, err := checkDNSKeyExists(key) require.NoError(t, err) assert.False(t, exists, "Key %s should NOT exist after cleanup", key) } } func checkDNSKeyExists(key string) (bool, error) { cmd := exec.Command(scutilPath) cmd.Stdin = strings.NewReader("show " + key + "\nquit\n") output, err := cmd.CombinedOutput() if err != nil { if strings.Contains(string(output), "No such key") { return false, nil } return false, err } return !strings.Contains(string(output), "No such key"), nil } func removeTestDNSKey(key string) error { cmd := exec.Command(scutilPath) cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n") _, err := cmd.CombinedOutput() return err } func TestGetOriginalNameservers(t *testing.T) { configurator := &systemConfigurator{ createdKeys: make(map[string]struct{}), origNameservers: []netip.Addr{ netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("1.1.1.1"), }, } servers := configurator.getOriginalNameservers() assert.Len(t, servers, 2) assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0]) assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1]) } func TestGetOriginalNameserversFromSystem(t *testing.T) { configurator := &systemConfigurator{ createdKeys: make(map[string]struct{}), } _, err := configurator.getSystemDNSSettings() require.NoError(t, err) servers := configurator.getOriginalNameservers() require.NotEmpty(t, servers, "expected at least one DNS server from system configuration") for _, server := range servers { assert.True(t, server.IsValid(), "server address should be valid") assert.False(t, server.IsUnspecified(), "server address should not be unspecified") } t.Logf("found %d original nameservers: %v", len(servers), servers) } func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) { t.Helper() tmpDir := t.TempDir() stateFile := filepath.Join(tmpDir, "state.json") sm := statemanager.New(stateFile) sm.RegisterState(&ShutdownState{}) sm.Start() configurator := &systemConfigurator{ createdKeys: make(map[string]struct{}), } searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix) matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix) localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix) cleanup := func() { _ = sm.Stop(context.Background()) for _, key := range []string{searchKey, matchKey, localKey} { _ = removeTestDNSKey(key) } } return configurator, sm, cleanup } func TestOriginalNameserversNoTransition(t *testing.T) { netbirdIP := netip.MustParseAddr("100.64.0.1") testCases := []struct { name string routeAll bool }{ {"routeall_false", false}, {"routeall_true", true}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { configurator, sm, cleanup := setupTestConfigurator(t) defer cleanup() _, err := configurator.getSystemDNSSettings() require.NoError(t, err) initialServers := configurator.getOriginalNameservers() t.Logf("Initial servers: %v", initialServers) require.NotEmpty(t, initialServers) for _, srv := range initialServers { require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP") } config := HostDNSConfig{ ServerIP: netbirdIP, ServerPort: 53, RouteAll: tc.routeAll, Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, } for i := 1; i <= 2; i++ { err = configurator.applyDNSConfig(config, sm) require.NoError(t, err) servers := configurator.getOriginalNameservers() t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers) assert.Equal(t, initialServers, servers) } }) } } func TestOriginalNameserversRouteAllTransition(t *testing.T) { netbirdIP := netip.MustParseAddr("100.64.0.1") testCases := []struct { name string initialRoute bool }{ {"start_with_routeall_false", false}, {"start_with_routeall_true", true}, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { configurator, sm, cleanup := setupTestConfigurator(t) defer cleanup() _, err := configurator.getSystemDNSSettings() require.NoError(t, err) initialServers := configurator.getOriginalNameservers() t.Logf("Initial servers: %v", initialServers) require.NotEmpty(t, initialServers) config := HostDNSConfig{ ServerIP: netbirdIP, ServerPort: 53, RouteAll: tc.initialRoute, Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}}, } // First apply err = configurator.applyDNSConfig(config, sm) require.NoError(t, err) servers := configurator.getOriginalNameservers() t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers) assert.Equal(t, initialServers, servers) // Toggle RouteAll config.RouteAll = !tc.initialRoute err = configurator.applyDNSConfig(config, sm) require.NoError(t, err) servers = configurator.getOriginalNameservers() t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers) assert.Equal(t, initialServers, servers) // Toggle back config.RouteAll = tc.initialRoute err = configurator.applyDNSConfig(config, sm) require.NoError(t, err) servers = configurator.getOriginalNameservers() t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers) assert.Equal(t, initialServers, servers) for _, srv := range servers { assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP") } }) } }