mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
@@ -6,13 +6,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -33,13 +33,13 @@ type WindowsDNSConfigurator struct {
|
||||
}
|
||||
|
||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||
// Accepts a TUN device and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) {
|
||||
if tunDevice == nil {
|
||||
return nil, fmt.Errorf("TUN device is required")
|
||||
// Accepts an interface name and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||
if interfaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
guid, err := getInterfaceGUIDString(tunDevice)
|
||||
guid, err := getInterfaceGUIDString(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||
}
|
||||
@@ -268,24 +268,21 @@ func closeKey(closer io.Closer) {
|
||||
|
||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||
// This is required for registry-based DNS configuration on Windows
|
||||
func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
||||
if tunDevice == nil {
|
||||
return "", fmt.Errorf("TUN device is nil")
|
||||
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||
if interfaceName == "" {
|
||||
return "", fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// The wireguard-go Windows TUN device has a LUID() method
|
||||
// We need to use type assertion to access it
|
||||
type nativeTun interface {
|
||||
LUID() uint64
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
nativeDev, ok := tunDevice.(nativeTun)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
|
||||
luid, err := indexToLUID(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||
}
|
||||
|
||||
luid := nativeDev.LUID()
|
||||
|
||||
// Convert LUID to GUID using Windows API
|
||||
guid, err := luidToGUID(luid)
|
||||
if err != nil {
|
||||
@@ -295,6 +292,27 @@ func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
||||
return guid, nil
|
||||
}
|
||||
|
||||
// indexToLUID converts a Windows interface index to a LUID
|
||||
func indexToLUID(index uint32) (uint64, error) {
|
||||
var luid uint64
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
// Call the Windows API
|
||||
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||
uintptr(index),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
return luid, nil
|
||||
}
|
||||
|
||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||
// using the Windows ConvertInterface* APIs
|
||||
func luidToGUID(luid uint64) (string, error) {
|
||||
|
||||
Reference in New Issue
Block a user