Files
netbird/client/internal/dns/host_darwin_test.go
Maycon Santos d1ead2265b [client] Batch macOS DNS domains to avoid truncation (#5368)
* [client] Batch macOS DNS domains across multiple scutil keys to avoid truncation

scutil has undocumented limits: 99-element cap on d.add arrays and ~2048
  byte value buffer for SupplementalMatchDomains. Users with 60+ domains
  hit silent domain loss. This applies the same batching approach used on
  Windows (nrptMaxDomainsPerRule=50), splitting domains into indexed
  resolver keys (NetBird-Match-0, NetBird-Match-1, etc.) with 50-element
  and 1500-byte limits per key.

* check for all keys on getRemovableKeysWithDefaults

* use multi error
2026-02-18 19:14:09 +01:00

496 lines
13 KiB
Go

//go:build !ios
package dns
import (
"bufio"
"bytes"
"context"
"fmt"
"net/netip"
"os/exec"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
// TestDarwinDNSUncleanShutdownCleanup applies a DNS configuration through a
// systemConfigurator, persists the state manager's state to disk, then loads
// that state through a second manager reading the same file and checks that
// ShutdownState.Cleanup removes every key the configurator recorded in
// createdKeys.
//
// It is skipped in short mode because it is a scutil integration test.
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping scutil integration test in short mode")
	}

	tmpDir := t.TempDir()
	stateFile := filepath.Join(tmpDir, "state.json")

	sm := statemanager.New(stateFile)
	sm.RegisterState(&ShutdownState{})
	sm.Start()
	defer func() {
		require.NoError(t, sm.Stop(context.Background()))
	}()

	configurator := &systemConfigurator{
		createdKeys: make(map[string]struct{}),
	}
	localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)

	// Register the safety-net removal BEFORE applying anything: a failing
	// require below aborts the test, and only already-registered defers run.
	// The map is read when the defer executes, so it covers whatever keys were
	// recorded up to that point.
	defer func() {
		for key := range configurator.createdKeys {
			// Best effort: the key may already have been removed by Cleanup.
			_ = removeTestDNSKey(key)
		}
		_ = removeTestDNSKey(localKey)
	}()

	config := HostDNSConfig{
		ServerIP:   netip.MustParseAddr("100.64.0.1"),
		ServerPort: 53,
		RouteAll:   true,
		Domains: []DomainConfig{
			{Domain: "example.com", MatchOnly: true},
		},
	}

	err := configurator.applyDNSConfig(config, sm)
	require.NoError(t, err)

	require.NoError(t, sm.PersistState(context.Background()))

	// Snapshot the recorded keys so the post-cleanup verification works on a
	// stable list.
	createdKeys := make([]string, 0, len(configurator.createdKeys))
	for key := range configurator.createdKeys {
		createdKeys = append(createdKeys, key)
	}

	// Informational only: log which keys are present before cleanup runs.
	for _, key := range createdKeys {
		exists, err := checkDNSKeyExists(key)
		require.NoError(t, err)
		if exists {
			t.Logf("Key %s exists before cleanup", key)
		}
	}

	// Load the persisted state through a second manager backed by the same file.
	sm2 := statemanager.New(stateFile)
	sm2.RegisterState(&ShutdownState{})

	err = sm2.LoadState(&ShutdownState{})
	require.NoError(t, err)

	state := sm2.GetState(&ShutdownState{})
	if state == nil {
		t.Skip("State not saved, skipping cleanup test")
	}

	shutdownState, ok := state.(*ShutdownState)
	require.True(t, ok)

	err = shutdownState.Cleanup()
	require.NoError(t, err)

	for _, key := range createdKeys {
		exists, err := checkDNSKeyExists(key)
		require.NoError(t, err)
		assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
	}
}
// generateShortDomains returns count domains of the form a.com, b.com, ...,
// z.com, aa.com, ab.com, and so on: the label for index i is i written in
// letters-only (spreadsheet-column style) base 26, followed by ".com".
func generateShortDomains(count int) []string {
	result := make([]string, 0, count)
	for idx := 0; idx < count; idx++ {
		// Fill the label right-to-left; 16 bytes is more than enough for any
		// int (a 64-bit index needs at most 14 letters).
		var buf [16]byte
		pos := len(buf)
		for rest := idx; rest >= 0; rest = rest/26 - 1 {
			pos--
			buf[pos] = byte('a' + rest%26)
		}
		result = append(result, string(buf[pos:])+".com")
	}
	return result
}
// generateLongDomains returns count domains of the form
// subdomain-NNN.department.organization-name.example.com, where NNN is the
// zero-based index padded to at least three digits.
func generateLongDomains(count int) []string {
	result := make([]string, count)
	for idx := range result {
		result[idx] = fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", idx)
	}
	return result
}
// readDomainsFromKey asks scutil to "show" the given key and returns the
// entries listed under its SupplementalMatchDomains array, in the order they
// appear in the output. The test fails immediately if scutil returns an error
// or the output cannot be scanned.
func readDomainsFromKey(t *testing.T, key string) []string {
	t.Helper()

	scutil := exec.Command(scutilPath)
	scutil.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
	raw, err := scutil.Output()
	require.NoError(t, err, "scutil show should succeed")

	var domains []string
	insideArray := false
	lines := bufio.NewScanner(bytes.NewReader(raw))

scan:
	for lines.Scan() {
		entry := strings.TrimSpace(lines.Text())
		switch {
		case strings.HasPrefix(entry, "SupplementalMatchDomains") && strings.Contains(entry, "<array>"):
			// Header line of the array; its elements follow.
			insideArray = true
		case !insideArray:
			// Not in the array yet; ignore.
		case entry == "}":
			// Closing brace ends the array.
			break scan
		default:
			// Element lines look like: "0 : a.com"
			if _, domain, ok := strings.Cut(entry, " : "); ok {
				domains = append(domains, domain)
			}
		}
	}
	require.NoError(t, lines.Err())

	return domains
}
// TestSplitDomainsIntoBatches checks splitDomainsIntoBatches against a table of
// inputs. For every case it verifies that each batch stays within
// maxDomainsPerResolverEntry elements and maxDomainBytesPerResolverEntry bytes
// (domain lengths plus one separator byte between neighbours). Where requested
// it also verifies the batch count and that concatenating the batches yields
// the input unchanged and in order.
func TestSplitDomainsIntoBatches(t *testing.T) {
	tests := []struct {
		name    string
		domains []string
		// expectedCount is the exact number of batches expected. Zero means
		// "not checked", except for empty input where zero batches is asserted.
		expectedCount   int
		checkAllPresent bool
	}{
		{
			name:          "empty",
			domains:       nil,
			expectedCount: 0,
		},
		{
			name:            "under_limit",
			domains:         generateShortDomains(10),
			expectedCount:   1,
			checkAllPresent: true,
		},
		{
			name:            "at_element_limit",
			domains:         generateShortDomains(50),
			expectedCount:   1,
			checkAllPresent: true,
		},
		{
			name:            "over_element_limit",
			domains:         generateShortDomains(51),
			expectedCount:   2,
			checkAllPresent: true,
		},
		{
			name:            "triple_element_limit",
			domains:         generateShortDomains(150),
			expectedCount:   3,
			checkAllPresent: true,
		},
		{
			name:            "long_domains_hit_byte_limit",
			domains:         generateLongDomains(50),
			checkAllPresent: true,
		},
		{
			name:            "500_short_domains",
			domains:         generateShortDomains(500),
			expectedCount:   10,
			checkAllPresent: true,
		},
		{
			name:            "500_long_domains",
			domains:         generateLongDomains(500),
			checkAllPresent: true,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			batches := splitDomainsIntoBatches(tc.domains)

			// The count is checked when a positive expectation is given, and
			// also for empty input so the "empty" case actually asserts
			// something instead of passing vacuously.
			if tc.expectedCount > 0 || len(tc.domains) == 0 {
				assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
			}

			// Verify each batch respects limits
			for i, batch := range batches {
				assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
					"batch %d exceeds element limit", i)

				// Count one separator byte between consecutive domains.
				totalBytes := 0
				for j, d := range batch {
					if j > 0 {
						totalBytes++
					}
					totalBytes += len(d)
				}
				assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
					"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
			}

			if tc.checkAllPresent {
				var all []string
				for _, batch := range batches {
					all = append(all, batch...)
				}
				assert.Equal(t, tc.domains, all, "all domains should be present in order")
			}
		})
	}
}
// TestMatchDomainBatching writes domain lists of increasing size through
// addBatchedDomains and then reads every indexed key back, expecting the
// concatenated result to equal the input in order. Skipped in short mode
// because it is a scutil integration test.
func TestMatchDomainBatching(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping scutil integration test in short mode")
	}

	type batchingCase struct {
		name      string
		count     int
		generator func(int) []string
	}
	cases := []batchingCase{
		{name: "short_10", count: 10, generator: generateShortDomains},
		{name: "short_50", count: 50, generator: generateShortDomains},
		{name: "short_100", count: 100, generator: generateShortDomains},
		{name: "short_200", count: 200, generator: generateShortDomains},
		{name: "short_500", count: 500, generator: generateShortDomains},
		{name: "long_10", count: 10, generator: generateLongDomains},
		{name: "long_50", count: 50, generator: generateLongDomains},
		{name: "long_100", count: 100, generator: generateLongDomains},
		{name: "long_200", count: 200, generator: generateLongDomains},
		{name: "long_500", count: 500, generator: generateLongDomains},
	}

	serverIP := netip.MustParseAddr("100.64.0.1")

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			cfg := &systemConfigurator{
				createdKeys: make(map[string]struct{}),
			}
			// Remove whatever keys were recorded, even if the subtest fails.
			defer func() {
				for created := range cfg.createdKeys {
					_ = removeTestDNSKey(created)
				}
			}()

			want := tc.generator(tc.count)
			require.NoError(t, cfg.addBatchedDomains(matchSuffix, want, serverIP, 53, false))

			batches := splitDomainsIntoBatches(want)
			t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))

			// Collect the domains stored under each indexed key, in index order.
			var got []string
			for idx := 0; idx < len(batches); idx++ {
				key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, idx)

				exists, err := checkDNSKeyExists(key)
				require.NoError(t, err)
				require.True(t, exists, "key %s should exist", key)

				got = append(got, readDomainsFromKey(t, key)...)
			}

			t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
			assert.Equal(t, tc.count, len(got), "all domains should be readable")
			assert.Equal(t, want, got, "domains should match in order")
		})
	}
}
// checkDNSKeyExists runs scutil with a "show" command for key and reports
// whether the key is present. Output containing "No such key" means the key is
// absent (false, nil) regardless of the command's exit status; any other
// command failure is returned as an error.
func checkDNSKeyExists(key string) (bool, error) {
	scutil := exec.Command(scutilPath)
	scutil.Stdin = strings.NewReader("show " + key + "\nquit\n")

	combined, runErr := scutil.CombinedOutput()
	missing := strings.Contains(string(combined), "No such key")

	switch {
	case missing:
		return false, nil
	case runErr != nil:
		return false, runErr
	default:
		return true, nil
	}
}
// removeTestDNSKey runs scutil with a "remove" command for key and returns the
// command's error, if any. The command output is discarded.
func removeTestDNSKey(key string) error {
	script := strings.NewReader("remove " + key + "\nquit\n")

	scutil := exec.Command(scutilPath)
	scutil.Stdin = script

	if _, err := scutil.CombinedOutput(); err != nil {
		return err
	}
	return nil
}
// TestGetOriginalNameservers checks that getOriginalNameservers returns the
// addresses stored in the configurator's origNameservers field, in order.
func TestGetOriginalNameservers(t *testing.T) {
	// Kept separate from the struct literal below so the comparison does not
	// share a backing array with the configurator's own slice.
	expected := []netip.Addr{
		netip.MustParseAddr("8.8.8.8"),
		netip.MustParseAddr("1.1.1.1"),
	}

	configurator := &systemConfigurator{
		createdKeys: make(map[string]struct{}),
		origNameservers: []netip.Addr{
			netip.MustParseAddr("8.8.8.8"),
			netip.MustParseAddr("1.1.1.1"),
		},
	}

	servers := configurator.getOriginalNameservers()

	// require (not assert): indexing below would panic on a short slice, which
	// would hide the real failure behind an index-out-of-range crash.
	require.Len(t, servers, len(expected))
	for i, want := range expected {
		assert.Equal(t, want, servers[i])
	}
}
// TestGetOriginalNameserversFromSystem calls getSystemDNSSettings on a fresh
// configurator and then expects getOriginalNameservers to return at least one
// address, each of which is valid and not the unspecified address.
func TestGetOriginalNameserversFromSystem(t *testing.T) {
	cfg := &systemConfigurator{createdKeys: map[string]struct{}{}}

	_, err := cfg.getSystemDNSSettings()
	require.NoError(t, err)

	nameservers := cfg.getOriginalNameservers()
	require.NotEmpty(t, nameservers, "expected at least one DNS server from system configuration")

	for i := range nameservers {
		addr := nameservers[i]
		assert.True(t, addr.IsValid(), "server address should be valid")
		assert.False(t, addr.IsUnspecified(), "server address should not be unspecified")
	}

	t.Logf("found %d original nameservers: %v", len(nameservers), nameservers)
}
// setupTestConfigurator builds an empty systemConfigurator together with a
// started state manager backed by a state.json file in a per-test temp
// directory. The returned function stops the manager and removes, on a
// best-effort basis, every key recorded in createdKeys plus the search, match
// and local keys built from netbirdDNSStateKeyFormat.
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
	t.Helper()

	configurator := &systemConfigurator{createdKeys: map[string]struct{}{}}

	sm := statemanager.New(filepath.Join(t.TempDir(), "state.json"))
	sm.RegisterState(&ShutdownState{})
	sm.Start()

	teardown := func() {
		_ = sm.Stop(context.Background())

		for created := range configurator.createdKeys {
			_ = removeTestDNSKey(created)
		}

		// Old-format keys and the local key may also be present; remove them too.
		searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
		_ = removeTestDNSKey(searchKey)
		matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
		_ = removeTestDNSKey(matchKey)
		localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
		_ = removeTestDNSKey(localKey)
	}

	return configurator, sm, teardown
}
// TestOriginalNameserversNoTransition applies the same DNS configuration twice,
// once per RouteAll setting, and expects getOriginalNameservers to keep
// returning the list captured before the first apply. It also requires that the
// initial list does not contain the NetBird server IP.
func TestOriginalNameserversNoTransition(t *testing.T) {
	netbirdIP := netip.MustParseAddr("100.64.0.1")

	for _, tc := range []struct {
		name     string
		routeAll bool
	}{
		{name: "routeall_false", routeAll: false},
		{name: "routeall_true", routeAll: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			configurator, sm, cleanup := setupTestConfigurator(t)
			defer cleanup()

			_, err := configurator.getSystemDNSSettings()
			require.NoError(t, err)

			baseline := configurator.getOriginalNameservers()
			t.Logf("Initial servers: %v", baseline)
			require.NotEmpty(t, baseline)
			for idx := range baseline {
				require.NotEqual(t, netbirdIP, baseline[idx], "initial servers should not contain NetBird IP")
			}

			config := HostDNSConfig{
				ServerIP:   netbirdIP,
				ServerPort: 53,
				RouteAll:   tc.routeAll,
				Domains:    []DomainConfig{{Domain: "example.com", MatchOnly: true}},
			}

			// Apply the identical config twice; the original list must not change.
			for attempt := 1; attempt <= 2; attempt++ {
				require.NoError(t, configurator.applyDNSConfig(config, sm))

				current := configurator.getOriginalNameservers()
				t.Logf("After apply %d (RouteAll=%v): %v", attempt, tc.routeAll, current)
				assert.Equal(t, baseline, current)
			}
		})
	}
}
// TestOriginalNameserversRouteAllTransition applies a DNS configuration, flips
// RouteAll, then flips it back, and expects getOriginalNameservers to return
// the list captured before the first apply after every step. After the last
// step it also checks that none of the servers is the NetBird server IP.
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
	netbirdIP := netip.MustParseAddr("100.64.0.1")

	for _, tc := range []struct {
		name         string
		initialRoute bool
	}{
		{name: "start_with_routeall_false", initialRoute: false},
		{name: "start_with_routeall_true", initialRoute: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			configurator, sm, cleanup := setupTestConfigurator(t)
			defer cleanup()

			_, err := configurator.getSystemDNSSettings()
			require.NoError(t, err)

			initialServers := configurator.getOriginalNameservers()
			t.Logf("Initial servers: %v", initialServers)
			require.NotEmpty(t, initialServers)

			config := HostDNSConfig{
				ServerIP:   netbirdIP,
				ServerPort: 53,
				RouteAll:   tc.initialRoute,
				Domains:    []DomainConfig{{Domain: "example.com", MatchOnly: true}},
			}

			// Step 0: first apply; step 1: toggled; step 2: toggled back.
			sequence := [...]bool{tc.initialRoute, !tc.initialRoute, tc.initialRoute}

			servers := initialServers
			for step, routeAll := range sequence {
				config.RouteAll = routeAll
				require.NoError(t, configurator.applyDNSConfig(config, sm))

				servers = configurator.getOriginalNameservers()
				switch step {
				case 0:
					t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
				case 1:
					t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
				default:
					t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
				}
				assert.Equal(t, initialServers, servers)
			}

			for idx := range servers {
				assert.NotEqual(t, netbirdIP, servers[idx], "servers should not contain NetBird IP")
			}
		})
	}
}