mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
@@ -5,17 +5,17 @@ package dns
|
|||||||
import "fmt"
|
import "fmt"
|
||||||
|
|
||||||
// DetectBestConfigurator returns the Windows DNS configurator
|
// DetectBestConfigurator returns the Windows DNS configurator
|
||||||
// guid is the network interface GUID
|
// ifaceName should be the network interface GUID on Windows
|
||||||
func DetectBestConfigurator(guid string) (DNSConfigurator, error) {
|
func DetectBestConfigurator(ifaceName string) (DNSConfigurator, error) {
|
||||||
if guid == "" {
|
if ifaceName == "" {
|
||||||
return nil, fmt.Errorf("interface GUID is required for Windows")
|
return nil, fmt.Errorf("interface GUID is required for Windows")
|
||||||
}
|
}
|
||||||
return NewWindowsDNSConfigurator(guid)
|
return newWindowsDNSConfiguratorFromGUID(ifaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSystemDNS returns the current system DNS servers for the given interface
|
// GetSystemDNS returns the current system DNS servers for the given interface
|
||||||
func GetSystemDNS(guid string) ([]string, error) {
|
func GetSystemDNS(ifaceName string) ([]string, error) {
|
||||||
configurator, err := NewWindowsDNSConfigurator(guid)
|
configurator, err := newWindowsDNSConfiguratorFromGUID(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create configurator: %w", err)
|
return nil, fmt.Errorf("create configurator: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,8 +8,11 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -30,8 +33,25 @@ type WindowsDNSConfigurator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||||
// guid is the network interface GUID
|
// Accepts a TUN device and extracts the GUID internally
|
||||||
func NewWindowsDNSConfigurator(guid string) (*WindowsDNSConfigurator, error) {
|
func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) {
|
||||||
|
if tunDevice == nil {
|
||||||
|
return nil, fmt.Errorf("TUN device is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
guid, err := getInterfaceGUIDString(tunDevice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &WindowsDNSConfigurator{
|
||||||
|
guid: guid,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newWindowsDNSConfiguratorFromGUID creates a configurator from a GUID string
|
||||||
|
// This is an internal function for use by DetectBestConfigurator
|
||||||
|
func newWindowsDNSConfiguratorFromGUID(guid string) (*WindowsDNSConfigurator, error) {
|
||||||
if guid == "" {
|
if guid == "" {
|
||||||
return nil, fmt.Errorf("interface GUID is required")
|
return nil, fmt.Errorf("interface GUID is required")
|
||||||
}
|
}
|
||||||
@@ -245,3 +265,61 @@ func closeKey(closer io.Closer) {
|
|||||||
fmt.Printf("warning: failed to close registry key: %v\n", err)
|
fmt.Printf("warning: failed to close registry key: %v\n", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||||
|
// This is required for registry-based DNS configuration on Windows
|
||||||
|
func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
||||||
|
if tunDevice == nil {
|
||||||
|
return "", fmt.Errorf("TUN device is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// The wireguard-go Windows TUN device has a LUID() method
|
||||||
|
// We need to use type assertion to access it
|
||||||
|
type nativeTun interface {
|
||||||
|
LUID() uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
nativeDev, ok := tunDevice.(nativeTun)
|
||||||
|
if !ok {
|
||||||
|
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
|
||||||
|
}
|
||||||
|
|
||||||
|
luid := nativeDev.LUID()
|
||||||
|
|
||||||
|
// Convert LUID to GUID using Windows API
|
||||||
|
guid, err := luidToGUID(luid)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return guid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||||
|
// using the Windows ConvertInterface* APIs
|
||||||
|
func luidToGUID(luid uint64) (string, error) {
|
||||||
|
var guid windows.GUID
|
||||||
|
|
||||||
|
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
|
||||||
|
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||||
|
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||||
|
|
||||||
|
// Call the Windows API
|
||||||
|
// NET_LUID is a 64-bit value on Windows
|
||||||
|
ret, _, err := convertLuidToGuid.Call(
|
||||||
|
uintptr(unsafe.Pointer(&luid)),
|
||||||
|
uintptr(unsafe.Pointer(&guid)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != 0 {
|
||||||
|
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format the GUID as a string with curly braces
|
||||||
|
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
|
||||||
|
guid.Data1, guid.Data2, guid.Data3,
|
||||||
|
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
|
||||||
|
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
|
||||||
|
|
||||||
|
return guidStr, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,31 +12,23 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||||
// Uses registry-based configuration (requires interface GUID)
|
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||||
if dnsProxy == nil {
|
if dnsProxy == nil {
|
||||||
return fmt.Errorf("DNS proxy is nil")
|
return fmt.Errorf("DNS proxy is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
// On Windows, we need to get the interface GUID from the TUN device
|
|
||||||
// The interfaceName parameter is ignored on Windows
|
|
||||||
if tdev == nil {
|
if tdev == nil {
|
||||||
return fmt.Errorf("TUN device is not available")
|
return fmt.Errorf("TUN device is not available")
|
||||||
}
|
}
|
||||||
|
|
||||||
guid, err := GetInterfaceGUIDString(tdev)
|
var err error
|
||||||
if err != nil {
|
configurator, err = platform.NewWindowsDNSConfigurator(tdev)
|
||||||
return fmt.Errorf("failed to get interface GUID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Retrieved interface GUID: %s for interface name: %s", guid, interfaceName)
|
|
||||||
|
|
||||||
configurator, err = platform.NewWindowsDNSConfigurator(guid)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Using Windows registry DNS configurator for GUID: %s", guid)
|
logger.Info("Using Windows registry DNS configurator for interface: %s", interfaceName)
|
||||||
|
|
||||||
// Get current DNS servers before changing
|
// Get current DNS servers before changing
|
||||||
currentDNS, err := configurator.GetCurrentDNS()
|
currentDNS, err := configurator.GetCurrentDNS()
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package olm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetInterfaceGUIDString is only implemented for Windows
|
|
||||||
// This stub is provided for compilation on other platforms
|
|
||||||
func GetInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
|
||||||
return "", fmt.Errorf("GetInterfaceGUIDString is only supported on Windows")
|
|
||||||
}
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
//go:build windows
|
|
||||||
|
|
||||||
package olm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GetInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
|
||||||
// This is required for registry-based DNS configuration on Windows
|
|
||||||
func GetInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
|
||||||
if tunDevice == nil {
|
|
||||||
return "", fmt.Errorf("TUN device is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// The wireguard-go Windows TUN device has a LUID() method
|
|
||||||
// We need to use type assertion to access it
|
|
||||||
type nativeTun interface {
|
|
||||||
LUID() uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
nativeDev, ok := tunDevice.(nativeTun)
|
|
||||||
if !ok {
|
|
||||||
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
|
|
||||||
}
|
|
||||||
|
|
||||||
luid := nativeDev.LUID()
|
|
||||||
|
|
||||||
// Convert LUID to GUID using Windows API
|
|
||||||
guid, err := luidToGUID(luid)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to convert LUID to GUID: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return guid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
|
||||||
// using the Windows ConvertInterface* APIs
|
|
||||||
func luidToGUID(luid uint64) (string, error) {
|
|
||||||
var guid windows.GUID
|
|
||||||
|
|
||||||
// Load the iphlpapi.dll and get the ConvertInterfaceLuidToGuid function
|
|
||||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
|
||||||
convertLuidToGuid := iphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
|
||||||
|
|
||||||
// Call the Windows API
|
|
||||||
// NET_LUID is a 64-bit value on Windows
|
|
||||||
ret, _, err := convertLuidToGuid.Call(
|
|
||||||
uintptr(unsafe.Pointer(&luid)),
|
|
||||||
uintptr(unsafe.Pointer(&guid)),
|
|
||||||
)
|
|
||||||
|
|
||||||
if ret != 0 {
|
|
||||||
return "", fmt.Errorf("ConvertInterfaceLuidToGuid failed with code %d: %w", ret, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Format the GUID as a string with curly braces
|
|
||||||
guidStr := fmt.Sprintf("{%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X}",
|
|
||||||
guid.Data1, guid.Data2, guid.Data3,
|
|
||||||
guid.Data4[0], guid.Data4[1], guid.Data4[2], guid.Data4[3],
|
|
||||||
guid.Data4[4], guid.Data4[5], guid.Data4[6], guid.Data4[7])
|
|
||||||
|
|
||||||
return guidStr, nil
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user