From 430f2bf7fa381552e1ef8e02beddd1e4649fc15d Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Nov 2025 14:56:24 -0500 Subject: [PATCH] Reorg working windows Former-commit-id: ec5d1ef1d12584f7cc68d10cc4dea6d28e26d9c1 --- dns/platform/detect_windows.go | 12 ++--- dns/platform/windows.go | 82 +++++++++++++++++++++++++++++++++- olm/dns_override_windows.go | 16 ++----- olm/interface_guid_stub.go | 15 ------- olm/interface_guid_windows.go | 69 ---------------------------- 5 files changed, 90 insertions(+), 104 deletions(-) delete mode 100644 olm/interface_guid_stub.go delete mode 100644 olm/interface_guid_windows.go diff --git a/dns/platform/detect_windows.go b/dns/platform/detect_windows.go index 81576f4..d62cc94 100644 --- a/dns/platform/detect_windows.go +++ b/dns/platform/detect_windows.go @@ -5,17 +5,17 @@ package dns import "fmt" // DetectBestConfigurator returns the Windows DNS configurator -// guid is the network interface GUID -func DetectBestConfigurator(guid string) (DNSConfigurator, error) { - if guid == "" { +// ifaceName should be the network interface GUID on Windows +func DetectBestConfigurator(ifaceName string) (DNSConfigurator, error) { + if ifaceName == "" { return nil, fmt.Errorf("interface GUID is required for Windows") } - return NewWindowsDNSConfigurator(guid) + return newWindowsDNSConfiguratorFromGUID(ifaceName) } // GetSystemDNS returns the current system DNS servers for the given interface -func GetSystemDNS(guid string) ([]string, error) { - configurator, err := NewWindowsDNSConfigurator(guid) +func GetSystemDNS(ifaceName string) ([]string, error) { + configurator, err := newWindowsDNSConfiguratorFromGUID(ifaceName) if err != nil { return nil, fmt.Errorf("create configurator: %w", err) } diff --git a/dns/platform/windows.go b/dns/platform/windows.go index c5f3f21..52d6953 100644 --- a/dns/platform/windows.go +++ b/dns/platform/windows.go @@ -8,8 +8,11 @@ import ( "io" "net/netip" "syscall" + "unsafe" + "golang.org/x/sys/windows" "golang.org/x/sys/windows/registry" + "golang.zx2c4.com/wireguard/tun" ) var ( @@ -30,8 +33,25 @@ type WindowsDNSConfigurator struct { } // NewWindowsDNSConfigurator creates a new Windows DNS configurator -// guid is the network interface GUID -func NewWindowsDNSConfigurator(guid string) (*WindowsDNSConfigurator, error) { +// Accepts a TUN device and extracts the GUID internally +func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) { + if tunDevice == nil { + return nil, fmt.Errorf("TUN device is required") + } + + guid, err := getInterfaceGUIDString(tunDevice) + if err != nil { + return nil, fmt.Errorf("failed to get interface GUID: %w", err) + } + + return &WindowsDNSConfigurator{ + guid: guid, + }, nil +} + +// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string +// This is an internal function for use by DetectBestConfigurator +func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) { if guid == "" { return nil, fmt.Errorf("interface GUID is required") } @@ -245,3 +265,61 @@ func closeKey(closer io.Closer) { fmt.Printf("warning: failed to close registry key: %v\n", err) } } + +// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface +// This is required for registry-based DNS configuration on Windows +func getInterfaceGUIDString(tunDevice tun.Device) (string, error) { + if tunDevice == nil { + return "", fmt.Errorf("TUN device is nil") + } + + // The wireguard-go Windows TUN device has a LUID() method + // We need to use type assertion to access it + type nativeTun interface { + LUID() uint64 + } + + nativeDev, ok := tunDevice.(nativeTun) + if !ok { + return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)") + } + + luid := nativeDev.LUID() + + // Convert LUID to GUID using Windows API + guid, err := luidToGUID(luid) + if err != nil { + return "", fmt.Errorf("failed to convert LUID to GUID: %w", err) + } + + return guid, nil +} + +// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string +// using the Windows ConvertInterface* APIs +func luidToGUID(luid uint64) (string, error) { + var guid windows.GUID + + // Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function + iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll") + convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid") + + // Call the Windows API + // NET_LUID is a 64-bit value on Windows + ret, _, err := convertLuidToGuid.Call( + uintptr(unsafe.Pointer(&luid)), + uintptr(unsafe.Pointer(&guid)), + ) + + if ret != 0 { + return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err) + } + + // Format the GUID as a string with curly braces + guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}", + guid.Data1, guid.Data2, guid.Data3, + guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3], + guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7]) + + return guidStr, nil +} diff --git a/olm/dns_override_windows.go b/olm/dns_override_windows.go index 842723a..7de9cc9 100644 --- a/olm/dns_override_windows.go +++ b/olm/dns_override_windows.go @@ -12,31 +12,23 @@ import ( ) // SetupDNSOverride configures the system DNS to use the DNS proxy on Windows -// Uses registry-based configuration (requires interface GUID) +// Uses registry-based configuration (automatically extracts interface GUID) func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error { if dnsProxy == nil { return fmt.Errorf("DNS proxy is nil") } - // On Windows, we need to get the interface GUID from the TUN device - // The interfaceName parameter is ignored on Windows if tdev == nil { return fmt.Errorf("TUN device is not available") } - guid, err := GetInterfaceGUIDString(tdev) - if err != nil { - return fmt.Errorf("failed to get interface GUID: %w", err) - } - - logger.Info("Retrieved interface GUID: %s for interface name: %s", guid, interfaceName) - - configurator, err = platform.NewWindowsDNSConfigurator(guid) + var err error + configurator, err = platform.NewWindowsDNSConfigurator(tdev) if err != nil { return fmt.Errorf("failed to create Windows DNS configurator: %w", err) } - logger.Info("Using Windows registry DNS configurator for GUID: %s", guid) + logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName) // Get current DNS servers before changing currentDNS, err := configurator.GetCurrentDNS() diff --git a/olm/interface_guid_stub.go b/olm/interface_guid_stub.go deleted file mode 100644 index cf0ad6a..0000000 --- a/olm/interface_guid_stub.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build !windows - -package olm - -import ( - "fmt" - - "golang.zx2c4.com/wireguard/tun" -) - -// GetInterfaceGUIDString is only implemented for Windows -// This stub is provided for compilation on other platforms -func GetInterfaceGUIDString(tunDevice tun.Device) (string, error) { - return "", fmt.Errorf("GetInterfaceGUIDString is only supported on Windows") -} diff --git a/olm/interface_guid_windows.go b/olm/interface_guid_windows.go deleted file mode 100644 index 64ba91d..0000000 --- a/olm/interface_guid_windows.go +++ /dev/null @@ -1,69 +0,0 @@ -//go:build windows - -package olm - -import ( - "fmt" - "unsafe" - - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/tun" -) - -// GetInterfaceGUIDString retrieves the GUID string for a Windows TUN interface -// This is required for registry-based DNS configuration on Windows -func GetInterfaceGUIDString(tunDevice tun.Device) (string, error) { - if tunDevice == nil { - return "", fmt.Errorf("TUN device is nil") - } - - // The wireguard-go Windows TUN device has a LUID() method - // We need to use type assertion to access it - type nativeTun interface { - LUID() uint64 - } - - nativeDev, ok := tunDevice.(nativeTun) - if !ok { - return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)") - } - - luid := nativeDev.LUID() - - // Convert LUID to GUID using Windows API - guid, err := luidToGUID(luid) - if err != nil { - return "", fmt.Errorf("failed to convert LUID to GUID: %w", err) - } - - return guid, nil -} - -// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string -// using the Windows ConvertInterface* APIs -func luidToGUID(luid uint64) (string, error) { - var guid windows.GUID - - // Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function - iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll") - convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid") - - // Call the Windows API - // NET_LUID is a 64-bit value on Windows - ret, _, err := convertLuidToGuid.Call( - uintptr(unsafe.Pointer(&luid)), - uintptr(unsafe.Pointer(&guid)), - ) - - if ret != 0 { - return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err) - } - - // Format the GUID as a string with curly braces - guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}", - guid.Data1, guid.Data2, guid.Data3, - guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3], - guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7]) - - return guidStr, nil -}