Files
olm/dns/platform/windows.go
2025-12-18 11:33:59 -05:00

356 lines
9.7 KiB
Go

//go:build windows
package dns
import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
var (
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
)
const (
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
interfaceConfigNameServer = "NameServer"
interfaceConfigDhcpNameServer = "DhcpNameServer"
)
// WindowsDNSConfigurator manages DNS settings on Windows using the registry
type WindowsDNSConfigurator struct {
guid string
originalState *DNSState
}
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
// Accepts an interface name and extracts the GUID internally
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
if interfaceName == "" {
return nil, fmt.Errorf("interface name is required")
}
guid, err := getInterfaceGUIDString(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
}
return &WindowsDNSConfigurator{
guid: guid,
}, nil
}
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
// This is an internal function for use by DetectBestConfigurator
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
if guid == "" {
return nil, fmt.Errorf("interface GUID is required")
}
return &WindowsDNSConfigurator{
guid: guid,
}, nil
}
// Name returns the configurator name
func (w *WindowsDNSConfigurator) Name() string {
return "windows-registry"
}
// SetDNS sets the DNS servers and returns the original servers
func (w *WindowsDNSConfigurator) SetDNS(servers []netip.Addr) ([]netip.Addr, error) {
// Get current DNS settings before overriding
originalServers, err := w.GetCurrentDNS()
if err != nil {
return nil, fmt.Errorf("get current DNS: %w", err)
}
// Store original state
w.originalState = &DNSState{
OriginalServers: originalServers,
ConfiguratorName: w.Name(),
}
// Set new DNS servers
if err := w.setDNSServers(servers); err != nil {
return nil, fmt.Errorf("set DNS servers: %w", err)
}
// Flush DNS cache
if err := w.flushDNSCache(); err != nil {
// Non-fatal, just log
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return originalServers, nil
}
// RestoreDNS restores the original DNS configuration
func (w *WindowsDNSConfigurator) RestoreDNS() error {
if w.originalState == nil {
return fmt.Errorf("no original state to restore")
}
// Clear the static DNS setting
if err := w.clearDNSServers(); err != nil {
return fmt.Errorf("clear DNS servers: %w", err)
}
// Flush DNS cache
if err := w.flushDNSCache(); err != nil {
fmt.Printf("warning: failed to flush DNS cache: %v\n", err)
}
return nil
}
// CleanupUncleanShutdown removes any DNS configuration left over from a previous crash
// On Windows, we rely on the registry-based approach which doesn't leave orphaned state
// in the same way as macOS scutil. The DNS settings are tied to the interface which
// gets recreated on restart.
func (w *WindowsDNSConfigurator) CleanupUncleanShutdown() error {
// Windows DNS configuration via registry is interface-specific.
// When the WireGuard interface is recreated, it gets a new GUID,
// so there's no leftover state to clean up from previous sessions.
// The old interface's registry keys are effectively orphaned but harmless.
return nil
}
// GetCurrentDNS returns the currently configured DNS servers
func (w *WindowsDNSConfigurator) GetCurrentDNS() ([]netip.Addr, error) {
regKey, err := w.getInterfaceRegistryKey(registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Try to get static DNS first
nameServer, _, err := regKey.GetStringValue(interfaceConfigNameServer)
if err == nil && nameServer != "" {
return w.parseServerList(nameServer), nil
}
// Fall back to DHCP DNS
dhcpNameServer, _, err := regKey.GetStringValue(interfaceConfigDhcpNameServer)
if err == nil && dhcpNameServer != "" {
return w.parseServerList(dhcpNameServer), nil
}
return []netip.Addr{}, nil
}
// setDNSServers sets the DNS servers in the registry
func (w *WindowsDNSConfigurator) setDNSServers(servers []netip.Addr) error {
if len(servers) == 0 {
return fmt.Errorf("no DNS servers provided")
}
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Build comma-separated or space-separated list of servers
var serverList string
for i, server := range servers {
if i > 0 {
serverList += ","
}
serverList += server.String()
}
if err := regKey.SetStringValue(interfaceConfigNameServer, serverList); err != nil {
return fmt.Errorf("set NameServer: %w", err)
}
return nil
}
// clearDNSServers clears the static DNS server setting
func (w *WindowsDNSConfigurator) clearDNSServers() error {
regKey, err := w.getInterfaceRegistryKey(registry.SET_VALUE)
if err != nil {
return fmt.Errorf("get interface registry key: %w", err)
}
defer closeKey(regKey)
// Set empty string to revert to DHCP
if err := regKey.SetStringValue(interfaceConfigNameServer, ""); err != nil {
return fmt.Errorf("clear NameServer: %w", err)
}
return nil
}
// getInterfaceRegistryKey opens the registry key for the network interface
func (w *WindowsDNSConfigurator) getInterfaceRegistryKey(access uint32) (registry.Key, error) {
regKeyPath := interfaceConfigPath + `\` + w.guid
regKey, err := registry.OpenKey(registry.LOCAL_MACHINE, regKeyPath, access)
if err != nil {
return 0, fmt.Errorf("open HKEY_LOCAL_MACHINE\\%s: %w", regKeyPath, err)
}
return regKey, nil
}
// parseServerList parses a comma or space-separated list of DNS servers
func (w *WindowsDNSConfigurator) parseServerList(serverList string) []netip.Addr {
var servers []netip.Addr
// Split by comma or space
parts := splitByDelimiters(serverList, []rune{',', ' '})
for _, part := range parts {
if addr, err := netip.ParseAddr(part); err == nil {
servers = append(servers, addr)
}
}
return servers
}
// flushDNSCache flushes the Windows DNS resolver cache
func (w *WindowsDNSConfigurator) flushDNSCache() error {
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
defer func() {
if rec := recover(); rec != nil {
fmt.Printf("warning: DnsFlushResolverCache panicked: %v\n", rec)
}
}()
ret, _, err := dnsFlushResolverCacheFn.Call()
if ret == 0 {
if err != nil && !errors.Is(err, syscall.Errno(0)) {
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
}
return fmt.Errorf("DnsFlushResolverCache failed")
}
return nil
}
// splitByDelimiters splits a string by multiple delimiters
func splitByDelimiters(s string, delimiters []rune) []string {
var result []string
var current []rune
for _, char := range s {
isDelimiter := false
for _, delim := range delimiters {
if char == delim {
isDelimiter = true
break
}
}
if isDelimiter {
if len(current) > 0 {
result = append(result, string(current))
current = []rune{}
}
} else {
current = append(current, char)
}
}
if len(current) > 0 {
result = append(result, string(current))
}
return result
}
// closeKey closes a registry key and logs errors
func closeKey(closer io.Closer) {
if err := closer.Close(); err != nil {
fmt.Printf("warning: failed to close registry key: %v\n", err)
}
}
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
// This is required for registry-based DNS configuration on Windows
func getInterfaceGUIDString(interfaceName string) (string, error) {
if interfaceName == "" {
return "", fmt.Errorf("interface name is required")
}
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
}
luid, err := indexToLUID(uint32(iface.Index))
if err != nil {
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
}
// Convert LUID to GUID using Windows API
guid, err := luidToGUID(luid)
if err != nil {
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
}
return guid, nil
}
// indexToLUID converts a Windows interface index to a LUID
func indexToLUID(index uint32) (uint64, error) {
var luid uint64
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
// Call the Windows API
ret, _, err := convertInterfaceIndexToLuid.Call(
uintptr(index),
uintptr(unsafe.Pointer(&luid)),
)
if ret != 0 {
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
}
return luid, nil
}
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
// using the Windows ConvertInterface* APIs
func luidToGUID(luid uint64) (string, error) {
var guid windows.GUID
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
// Call the Windows API
// NET_LUID is a 64-bit value on Windows
ret, _, err := convertLuidToGuid.Call(
uintptr(unsafe.Pointer(&luid)),
uintptr(unsafe.Pointer(&guid)),
)
if ret != 0 {
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
}
// Format the GUID as a string with curly braces
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
guid.Data1, guid.Data2, guid.Data3,
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
return guidStr, nil
}