Extract applyRouteAll helper and reorder package declarations

This commit is contained in:
Viktor Liu
2026-05-05 18:16:28 +02:00
parent 4810e79a00
commit 6a201d12b5
4 changed files with 63 additions and 54 deletions

View File

@@ -21,16 +21,16 @@ const (
EnvStrict = "NB_DNS_FIREWALL_STRICT"
)
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
var defaultBlockedPorts = []uint16{53, 853}
// strictMode reports whether strict mode is enabled via env.
func strictMode() bool {
v, _ := strconv.ParseBool(os.Getenv(EnvStrict))
return v
}
// defaultBlockedPorts are the well-known DNS ports we block for non-netbird
// processes: 53 (plain DNS) and 853 (DNS-over-TLS).
var defaultBlockedPorts = []uint16{53, 853}
// blockedPorts returns the effective port list, honoring env overrides.
// A nil return means the firewall should not be installed.
func blockedPorts() []uint16 {

View File

@@ -4,12 +4,12 @@ package dnsfw
import "net/netip"
// New returns a no-op manager on non-Windows platforms.
func New() Manager {
return noopManager{}
}
type noopManager struct{}
func (noopManager) Enable(string, netip.Addr) error { return nil }
func (noopManager) Disable() error { return nil }
// New returns a no-op manager on non-Windows platforms.
func New() Manager {
return noopManager{}
}

View File

@@ -13,10 +13,10 @@ import (
"golang.org/x/sys/windows"
)
// New returns a Windows DNS firewall manager backed by WFP.
func New() Manager {
return &windowsManager{}
}
var (
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
)
type windowsManager struct {
mu sync.Mutex
@@ -80,6 +80,24 @@ func (m *windowsManager) Enable(ifaceGUID string, virtualDNSIP netip.Addr) error
return nil
}
func (m *windowsManager) Disable() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.disableLocked()
}
func (m *windowsManager) disableLocked() error {
if m.session == 0 {
return nil
}
if err := closeSession(m.session); err != nil {
return fmt.Errorf("close wfp session: %w", err)
}
m.session = 0
log.Info("dns firewall removed")
return nil
}
// failOrLog returns err unchanged in strict mode. In non-strict mode the
// error is logged and nil is returned.
func (m *windowsManager) failOrLog(strict bool, err error) error {
@@ -90,6 +108,11 @@ func (m *windowsManager) failOrLog(strict bool, err error) error {
return nil
}
// New returns a Windows DNS firewall manager backed by WFP.
func New() Manager {
return &windowsManager{}
}
// luidFromGUID converts a Windows interface GUID string to its LUID.
func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
defer func() {
@@ -111,26 +134,3 @@ func luidFromGUID(ifaceGUID string) (luid uint64, err error) {
}
return luid, nil
}
var (
modIphlpapi = windows.NewLazyDLL("iphlpapi.dll")
procConvertInterfaceGuidToLuid = modIphlpapi.NewProc("ConvertInterfaceGuidToLuid")
)
func (m *windowsManager) Disable() error {
m.mu.Lock()
defer m.mu.Unlock()
return m.disableLocked()
}
func (m *windowsManager) disableLocked() error {
if m.session == 0 {
return nil
}
if err := closeSession(m.session); err != nil {
return fmt.Errorf("close wfp session: %w", err)
}
m.session = 0
log.Info("dns firewall removed")
return nil
}

View File

@@ -172,24 +172,8 @@ func (r *registryConfigurator) disableWINSForInterface() error {
}
func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *statemanager.Manager) error {
if config.RouteAll {
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
return fmt.Errorf("dns firewall: %w", err)
}
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
} else {
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
if r.routingAll {
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
}
if err := r.applyRouteAll(config); err != nil {
return err
}
r.updateState(stateManager)
@@ -231,6 +215,31 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
return nil
}
func (r *registryConfigurator) applyRouteAll(config HostDNSConfig) error {
if config.RouteAll {
if err := r.dnsFirewall.Enable(r.guid, config.ServerIP); err != nil {
return fmt.Errorf("dns firewall: %w", err)
}
if err := r.addDNSSetupForAll(config.ServerIP); err != nil {
return fmt.Errorf("add dns setup: %w", err)
}
return nil
}
if err := r.dnsFirewall.Disable(); err != nil {
log.Errorf("disable dns firewall: %v", err)
}
if !r.routingAll {
return nil
}
if err := r.deleteInterfaceRegistryKeyProperty(interfaceConfigNameServerKey); err != nil {
return fmt.Errorf("delete interface registry key property: %w", err)
}
r.routingAll = false
log.Infof("removed %s as main DNS forwarder for this peer", config.ServerIP)
return nil
}
func (r *registryConfigurator) updateState(stateManager *statemanager.Manager) {
if err := stateManager.UpdateState(&ShutdownState{
Guid: r.guid,