mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Fix macOS state-based dns cleanup (#4701)
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -50,28 +51,21 @@ func (s *systemConfigurator) supportCustomPort() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
|
||||||
var err error
|
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
searchDomains []string
|
searchDomains []string
|
||||||
matchDomains []string
|
matchDomains []string
|
||||||
)
|
)
|
||||||
|
|
||||||
err = s.recordSystemDNSSettings(true)
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
log.Errorf("unable to update record of System's DNS config: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
searchDomains = append(searchDomains, "\"\"")
|
searchDomains = append(searchDomains, "\"\"")
|
||||||
err = s.addLocalDNS()
|
if err := s.addLocalDNS(); err != nil {
|
||||||
if err != nil {
|
log.Warnf("failed to add local DNS: %v", err)
|
||||||
log.Infof("failed to enable split DNS")
|
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -86,6 +80,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
var err error
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -95,6 +90,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
if len(searchDomains) != 0 {
|
if len(searchDomains) != 0 {
|
||||||
@@ -106,6 +102,7 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
s.updateState(stateManager)
|
||||||
|
|
||||||
if err := s.flushDNSCache(); err != nil {
|
if err := s.flushDNSCache(); err != nil {
|
||||||
log.Errorf("failed to flush DNS cache: %v", err)
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
@@ -114,6 +111,12 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *systemConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{CreatedKeys: maps.Keys(s.createdKeys)}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) string() string {
|
func (s *systemConfigurator) string() string {
|
||||||
return "scutil"
|
return "scutil"
|
||||||
}
|
}
|
||||||
@@ -167,18 +170,20 @@ func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
|||||||
func (s *systemConfigurator) addLocalDNS() error {
|
func (s *systemConfigurator) addLocalDNS() error {
|
||||||
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
if err := s.recordSystemDNSSettings(true); err != nil {
|
if err := s.recordSystemDNSSettings(true); err != nil {
|
||||||
log.Errorf("Unable to get system DNS configuration")
|
|
||||||
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
return fmt.Errorf("recordSystemDNSSettings(): %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
if s.systemDNSSettings.ServerIP.IsValid() && len(s.systemDNSSettings.Domains) != 0 {
|
if !s.systemDNSSettings.ServerIP.IsValid() || len(s.systemDNSSettings.Domains) == 0 {
|
||||||
err := s.addSearchDomains(localKey, strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("couldn't add local network DNS conf: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Info("Not enabling local DNS server")
|
log.Info("Not enabling local DNS server")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.addSearchDomains(
|
||||||
|
localKey,
|
||||||
|
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
111
client/internal/dns/host_darwin_test.go
Normal file
111
client/internal/dns/host_darwin_test.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
//go:build !ios
|
||||||
|
|
||||||
|
package dns
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
stateFile := filepath.Join(tmpDir, "state.json")
|
||||||
|
|
||||||
|
sm := statemanager.New(stateFile)
|
||||||
|
sm.RegisterState(&ShutdownState{})
|
||||||
|
sm.Start()
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, sm.Stop(context.Background()))
|
||||||
|
}()
|
||||||
|
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: netip.MustParseAddr("100.64.0.1"),
|
||||||
|
ServerPort: 53,
|
||||||
|
RouteAll: true,
|
||||||
|
Domains: []DomainConfig{
|
||||||
|
{Domain: "example.com", MatchOnly: true},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := configurator.applyDNSConfig(config, sm)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
|
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||||
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
if exists {
|
||||||
|
t.Logf("Key %s exists before cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sm2 := statemanager.New(stateFile)
|
||||||
|
sm2.RegisterState(&ShutdownState{})
|
||||||
|
err = sm2.LoadState(&ShutdownState{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
state := sm2.GetState(&ShutdownState{})
|
||||||
|
if state == nil {
|
||||||
|
t.Skip("State not saved, skipping cleanup test")
|
||||||
|
}
|
||||||
|
|
||||||
|
shutdownState, ok := state.(*ShutdownState)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
err = shutdownState.Cleanup()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(string(output), "No such key") {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return !strings.Contains(string(output), "No such key"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeTestDNSKey(key string) error {
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader("remove " + key + "\nquit\n")
|
||||||
|
_, err := cmd.CombinedOutput()
|
||||||
|
return err
|
||||||
|
}
|
||||||
@@ -179,13 +179,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var searchDomains, matchDomains []string
|
var searchDomains, matchDomains []string
|
||||||
for _, dConf := range config.Domains {
|
for _, dConf := range config.Domains {
|
||||||
@@ -212,13 +206,7 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
r.nrptEntryCount = 0
|
r.nrptEntryCount = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
r.updateState(stateManager)
|
||||||
Guid: r.guid,
|
|
||||||
GPO: r.gpo,
|
|
||||||
NRPTEntryCount: r.nrptEntryCount,
|
|
||||||
}); err != nil {
|
|
||||||
log.Errorf("failed to update shutdown state: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.updateSearchDomains(searchDomains); err != nil {
|
if err := r.updateSearchDomains(searchDomains); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@@ -229,6 +217,16 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
|
||||||
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
|
Guid: r.guid,
|
||||||
|
GPO: r.gpo,
|
||||||
|
NRPTEntryCount: r.nrptEntryCount,
|
||||||
|
}); err != nil {
|
||||||
|
log.Errorf("failed to update shutdown state: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigNameServerKey, ip.String()); err != nil {
|
||||||
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
return fmt.Errorf("adding dns setup for all failed: %w", err)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ShutdownState struct {
|
type ShutdownState struct {
|
||||||
|
CreatedKeys []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *ShutdownState) Name() string {
|
func (s *ShutdownState) Name() string {
|
||||||
@@ -19,6 +20,10 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create host manager: %w", err)
|
return fmt.Errorf("create host manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, key := range s.CreatedKeys {
|
||||||
|
manager.createdKeys[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
if err := manager.restoreUncleanShutdownDNS(); err != nil {
|
||||||
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
return fmt.Errorf("restore unclean shutdown dns: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user