diff --git a/client/internal/dns/dnsfw/config.go b/client/internal/dns/dnsfw/config.go index 6efa6b371..372a34686 100644 --- a/client/internal/dns/dnsfw/config.go +++ b/client/internal/dns/dnsfw/config.go @@ -21,16 +21,16 @@ const ( EnvStrict = "NB_DNS_FIREWALL_STRICT" ) +// defaultBlockedPorts are the well-known DNS ports we block for non-netbird +// processes: 53 (plain DNS) and 853 (DNS-over-TLS). +var defaultBlockedPorts = []uint16{53, 853} + // strictMode reports whether strict mode is enabled via env. func strictMode() bool { v, _ := strconv.ParseBool(os.Getenv(EnvStrict)) return v } -// defaultBlockedPorts are the well-known DNS ports we block for non-netbird -// processes: 53 (plain DNS) and 853 (DNS-over-TLS). -var defaultBlockedPorts = []uint16{53, 853} - // blockedPorts returns the effective port list, honoring env overrides. // A nil return means the firewall should not be installed. func blockedPorts() []uint16 { diff --git a/client/internal/dns/dnsfw/dnsfw_other.go b/client/internal/dns/dnsfw/dnsfw_other.go index 458c443a1..142a8417e 100644 --- a/client/internal/dns/dnsfw/dnsfw_other.go +++ b/client/internal/dns/dnsfw/dnsfw_other.go @@ -4,12 +4,12 @@ package dnsfw import "net/netip" -// New returns a no-op manager on non-Windows platforms. -func New() Manager { - return noopManager{} -} - type noopManager struct{} func (noopManager) Enable(string, netip.Addr) error { return nil } func (noopManager) Disable() error { return nil } + +// New returns a no-op manager on non-Windows platforms. +func New() Manager { + return noopManager{} +} diff --git a/client/internal/dns/dnsfw/dnsfw_windows.go b/client/internal/dns/dnsfw/dnsfw_windows.go index f9dc53734..e1ae412a7 100644 --- a/client/internal/dns/dnsfw/dnsfw_windows.go +++ b/client/internal/dns/dnsfw/dnsfw_windows.go @@ -13,10 +13,10 @@ import ( "golang.org/x/sys/windows" ) -// New returns a Windows DNS firewall manager backed by WFP. -func New() Manager { - return &windowsManager{} -} +var ( + modIphlpapi = windows.NewLazyDLL("iphlpapi.dll") + procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid") +) type windowsManager struct { mu sync.Mutex @@ -80,6 +80,24 @@ func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error return nil } +func (m *windowsManager) Disable() error { + m.mu.Lock() + defer m.mu.Unlock() + return m.disableLocked() +} + +func (m *windowsManager) disableLocked() error { + if m.session == 0 { + return nil + } + if err := closeSession(m.session); err != nil { + return fmt.Errorf("close wfp session: %w", err) + } + m.session = 0 + log.Info("dns firewall removed") + return nil +} + // failOrLog returns err unchanged in strict mode. In non-strict mode the // error is logged and nil is returned. func (m *windowsManager) failOrLog(strict bool, err error) error { @@ -90,6 +108,11 @@ func (m *windowsManager) failOrLog(strict bool, err error) error { return nil } +// New returns a Windows DNS firewall manager backed by WFP. +func New() Manager { + return &windowsManager{} +} + // luidFromGUID converts a Windows interface GUID string to its LUID. func luidFromGUID(ifaceGUID string) (luid uint64, err error) { defer func() { @@ -111,26 +134,3 @@ func luidFromGUID(ifaceGUID string) (luid uint64, err error) { } return luid, nil } - -var ( - modIphlpapi = windows.NewLazyDLL("iphlpapi.dll") - procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid") -) - -func (m *windowsManager) Disable() error { - m.mu.Lock() - defer m.mu.Unlock() - return m.disableLocked() -} - -func (m *windowsManager) disableLocked() error { - if m.session == 0 { - return nil - } - if err := closeSession(m.session); err != nil { - return fmt.Errorf("close wfp session: %w", err) - } - m.session = 0 - log.Info("dns firewall removed") - return nil -} diff --git a/client/internal/dns/host_windows.go b/client/internal/dns/host_windows.go index d644b428d..2a1cfed1f 100644 --- a/client/internal/dns/host_windows.go +++ b/client/internal/dns/host_windows.go @@ -172,24 +172,8 @@ func (r *registryConfigurator) disableWINSForInterface() error { } func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error { - if config.RouteAll { - if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil { - return fmt.Errorf("dns firewall: %w", err) - } - if err := r.addDNSSetupForAll(config.ServerIP); err != nil { - return fmt.Errorf("add dns setup: %w", err) - } - } else { - if err := r.dnsFirewall.Disable(); err != nil { - log.Errorf("disable dns firewall: %v", err) - } - if r.routingAll { - if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil { - return fmt.Errorf("delete interface registry key property: %w", err) - } - r.routingAll = false - log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) - } + if err := r.applyRouteAll(config); err != nil { + return err } r.updateState(stateManager) @@ -231,6 +215,31 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager return nil } +func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error { + if config.RouteAll { + if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil { + return fmt.Errorf("dns firewall: %w", err) + } + if err := r.addDNSSetupForAll(config.ServerIP); err != nil { + return fmt.Errorf("add dns setup: %w", err) + } + return nil + } + + if err := r.dnsFirewall.Disable(); err != nil { + log.Errorf("disable dns firewall: %v", err) + } + if !r.routingAll { + return nil + } + if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil { + return fmt.Errorf("delete interface registry key property: %w", err) + } + r.routingAll = false + log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP) + return nil +} + func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) { if err := stateManager.UpdateState(&ShutdownState{ Guid: r.guid,