// Mirror of https://github.com/netbirdio/netbird.git
// (synced 2026-04-16 15:26:40 +00:00; 278 lines, 7.4 KiB, Go)
//go:build !ios
|
|
|
|
package dns
|
|
|
|
import (
|
|
"context"
|
|
"net/netip"
|
|
"os/exec"
|
|
"path/filepath"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
|
)
|
|
|
|
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("skipping scutil integration test in short mode")
|
|
}
|
|
|
|
tmpDir := t.TempDir()
|
|
stateFile := filepath.Join(tmpDir, "state.json")
|
|
|
|
sm := statemanager.New(stateFile)
|
|
sm.RegisterState(&ShutdownState{})
|
|
sm.Start()
|
|
defer func() {
|
|
require.NoError(t, sm.Stop(context.Background()))
|
|
}()
|
|
|
|
configurator := &systemConfigurator{
|
|
createdKeys: make(map[string]struct{}),
|
|
}
|
|
|
|
config := HostDNSConfig{
|
|
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
|
ServerPort: 53,
|
|
RouteAll: true,
|
|
Domains: []DomainConfig{
|
|
{Domain: "example.com", MatchOnly: true},
|
|
},
|
|
}
|
|
|
|
err := configurator.applyDNSConfig(config, sm)
|
|
require.NoError(t, err)
|
|
|
|
require.NoError(t, sm.PersistState(context.Background()))
|
|
|
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
|
|
|
defer func() {
|
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
|
_ = removeTestDNSKey(key)
|
|
}
|
|
}()
|
|
|
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
|
exists, err := checkDNSKeyExists(key)
|
|
require.NoError(t, err)
|
|
if exists {
|
|
t.Logf("Key %s exists before cleanup", key)
|
|
}
|
|
}
|
|
|
|
sm2 := statemanager.New(stateFile)
|
|
sm2.RegisterState(&ShutdownState{})
|
|
err = sm2.LoadState(&ShutdownState{})
|
|
require.NoError(t, err)
|
|
|
|
state := sm2.GetState(&ShutdownState{})
|
|
if state == nil {
|
|
t.Skip("State not saved, skipping cleanup test")
|
|
}
|
|
|
|
shutdownState, ok := state.(*ShutdownState)
|
|
require.True(t, ok)
|
|
|
|
err = shutdownState.Cleanup()
|
|
require.NoError(t, err)
|
|
|
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
|
exists, err := checkDNSKeyExists(key)
|
|
require.NoError(t, err)
|
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
|
}
|
|
}
|
|
|
|
func checkDNSKeyExists(key string) (bool, error) {
|
|
cmd := exec.Command(scutilPath)
|
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
|
output, err := cmd.CombinedOutput()
|
|
if err != nil {
|
|
if strings.Contains(string(output), "No such key") {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
return !strings.Contains(string(output), "No such key"), nil
|
|
}
|
|
|
|
func removeTestDNSKey(key string) error {
|
|
cmd := exec.Command(scutilPath)
|
|
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
|
_, err := cmd.CombinedOutput()
|
|
return err
|
|
}
|
|
|
|
func TestGetOriginalNameservers(t *testing.T) {
|
|
configurator := &systemConfigurator{
|
|
createdKeys: make(map[string]struct{}),
|
|
origNameservers: []netip.Addr{
|
|
netip.MustParseAddr("8.8.8.8"),
|
|
netip.MustParseAddr("1.1.1.1"),
|
|
},
|
|
}
|
|
|
|
servers := configurator.getOriginalNameservers()
|
|
assert.Len(t, servers, 2)
|
|
assert.Equal(t, netip.MustParseAddr("8.8.8.8"), servers[0])
|
|
assert.Equal(t, netip.MustParseAddr("1.1.1.1"), servers[1])
|
|
}
|
|
|
|
func TestGetOriginalNameserversFromSystem(t *testing.T) {
|
|
configurator := &systemConfigurator{
|
|
createdKeys: make(map[string]struct{}),
|
|
}
|
|
|
|
_, err := configurator.getSystemDNSSettings()
|
|
require.NoError(t, err)
|
|
|
|
servers := configurator.getOriginalNameservers()
|
|
|
|
require.NotEmpty(t, servers, "expected at least one DNS server from system configuration")
|
|
|
|
for _, server := range servers {
|
|
assert.True(t, server.IsValid(), "server address should be valid")
|
|
assert.False(t, server.IsUnspecified(), "server address should not be unspecified")
|
|
}
|
|
|
|
t.Logf("found %d original nameservers: %v", len(servers), servers)
|
|
}
|
|
|
|
func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Manager, func()) {
|
|
t.Helper()
|
|
|
|
tmpDir := t.TempDir()
|
|
stateFile := filepath.Join(tmpDir, "state.json")
|
|
sm := statemanager.New(stateFile)
|
|
sm.RegisterState(&ShutdownState{})
|
|
sm.Start()
|
|
|
|
configurator := &systemConfigurator{
|
|
createdKeys: make(map[string]struct{}),
|
|
}
|
|
|
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
|
|
|
cleanup := func() {
|
|
_ = sm.Stop(context.Background())
|
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
|
_ = removeTestDNSKey(key)
|
|
}
|
|
}
|
|
|
|
return configurator, sm, cleanup
|
|
}
|
|
|
|
func TestOriginalNameserversNoTransition(t *testing.T) {
|
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
routeAll bool
|
|
}{
|
|
{"routeall_false", false},
|
|
{"routeall_true", true},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
|
defer cleanup()
|
|
|
|
_, err := configurator.getSystemDNSSettings()
|
|
require.NoError(t, err)
|
|
initialServers := configurator.getOriginalNameservers()
|
|
t.Logf("Initial servers: %v", initialServers)
|
|
require.NotEmpty(t, initialServers)
|
|
|
|
for _, srv := range initialServers {
|
|
require.NotEqual(t, netbirdIP, srv, "initial servers should not contain NetBird IP")
|
|
}
|
|
|
|
config := HostDNSConfig{
|
|
ServerIP: netbirdIP,
|
|
ServerPort: 53,
|
|
RouteAll: tc.routeAll,
|
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
|
}
|
|
|
|
for i := 1; i <= 2; i++ {
|
|
err = configurator.applyDNSConfig(config, sm)
|
|
require.NoError(t, err)
|
|
|
|
servers := configurator.getOriginalNameservers()
|
|
t.Logf("After apply %d (RouteAll=%v): %v", i, tc.routeAll, servers)
|
|
assert.Equal(t, initialServers, servers)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOriginalNameserversRouteAllTransition(t *testing.T) {
|
|
netbirdIP := netip.MustParseAddr("100.64.0.1")
|
|
|
|
testCases := []struct {
|
|
name string
|
|
initialRoute bool
|
|
}{
|
|
{"start_with_routeall_false", false},
|
|
{"start_with_routeall_true", true},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
configurator, sm, cleanup := setupTestConfigurator(t)
|
|
defer cleanup()
|
|
|
|
_, err := configurator.getSystemDNSSettings()
|
|
require.NoError(t, err)
|
|
initialServers := configurator.getOriginalNameservers()
|
|
t.Logf("Initial servers: %v", initialServers)
|
|
require.NotEmpty(t, initialServers)
|
|
|
|
config := HostDNSConfig{
|
|
ServerIP: netbirdIP,
|
|
ServerPort: 53,
|
|
RouteAll: tc.initialRoute,
|
|
Domains: []DomainConfig{{Domain: "example.com", MatchOnly: true}},
|
|
}
|
|
|
|
// First apply
|
|
err = configurator.applyDNSConfig(config, sm)
|
|
require.NoError(t, err)
|
|
servers := configurator.getOriginalNameservers()
|
|
t.Logf("After first apply (RouteAll=%v): %v", tc.initialRoute, servers)
|
|
assert.Equal(t, initialServers, servers)
|
|
|
|
// Toggle RouteAll
|
|
config.RouteAll = !tc.initialRoute
|
|
err = configurator.applyDNSConfig(config, sm)
|
|
require.NoError(t, err)
|
|
servers = configurator.getOriginalNameservers()
|
|
t.Logf("After toggle (RouteAll=%v): %v", config.RouteAll, servers)
|
|
assert.Equal(t, initialServers, servers)
|
|
|
|
// Toggle back
|
|
config.RouteAll = tc.initialRoute
|
|
err = configurator.applyDNSConfig(config, sm)
|
|
require.NoError(t, err)
|
|
servers = configurator.getOriginalNameservers()
|
|
t.Logf("After toggle back (RouteAll=%v): %v", config.RouteAll, servers)
|
|
assert.Equal(t, initialServers, servers)
|
|
|
|
for _, srv := range servers {
|
|
assert.NotEqual(t, netbirdIP, srv, "servers should not contain NetBird IP")
|
|
}
|
|
})
|
|
}
|
|
}
|