Former-commit-id: 7ae705b1f1
This commit is contained in:
Owen
2025-11-24 16:16:52 -05:00
parent d54b7e3f14
commit 0802673048
5 changed files with 53 additions and 35 deletions

View File

@@ -6,13 +6,13 @@ import (
"errors"
"fmt"
"io"
"net"
"net/netip"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/tun"
)
var (
@@ -33,13 +33,13 @@ type WindowsDNSConfigurator struct {
}
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
// Accepts a TUN device and extracts the GUID internally
func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) {
if tunDevice == nil {
return nil, fmt.Errorf("TUN device is required")
// Accepts an interface name and extracts the GUID internally
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
if interfaceName == "" {
return nil, fmt.Errorf("interface name is required")
}
guid, err := getInterfaceGUIDString(tunDevice)
guid, err := getInterfaceGUIDString(interfaceName)
if err != nil {
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
}
@@ -268,24 +268,21 @@ func closeKey(closer io.Closer) {
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
// This is required for registry-based DNS configuration on Windows
func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
if tunDevice == nil {
return "", fmt.Errorf("TUN device is nil")
func getInterfaceGUIDString(interfaceName string) (string, error) {
if interfaceName == "" {
return "", fmt.Errorf("interface name is required")
}
// The wireguard-go Windows TUN device has a LUID() method
// We need to use type assertion to access it
type nativeTun interface {
LUID() uint64
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
}
nativeDev, ok := tunDevice.(nativeTun)
if !ok {
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
luid, err := indexToLUID(uint32(iface.Index))
if err != nil {
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
}
luid := nativeDev.LUID()
// Convert LUID to GUID using Windows API
guid, err := luidToGUID(luid)
if err != nil {
@@ -295,6 +292,27 @@ func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
return guid, nil
}
// indexToLUID converts a Windows interface index to a LUID
func indexToLUID(index uint32) (uint64, error) {
var luid uint64
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
// Call the Windows API
ret, _, err := convertInterfaceIndexToLuid.Call(
uintptr(index),
uintptr(unsafe.Pointer(&luid)),
)
if ret != 0 {
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
}
return luid, nil
}
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
// using the Windows ConvertInterface* APIs
func luidToGUID(luid uint64) (string, error) {