mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
@@ -11,6 +11,8 @@ import (
|
|||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var configurator platform.DNSConfigurator
|
||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||||
// Uses scutil for DNS configuration
|
// Uses scutil for DNS configuration
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var configurator platform.DNSConfigurator
|
||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||||
// Tries systemd-resolved, NetworkManager, resolvconf, or falls back to /etc/resolv.conf
|
// Tries systemd-resolved, NetworkManager, resolvconf, or falls back to /etc/resolv.conf
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
platform "github.com/fosrl/olm/dns/platform"
|
platform "github.com/fosrl/olm/dns/platform"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var configurator platform.DNSConfigurator
|
||||||
|
|
||||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||||
@@ -18,12 +20,8 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
|||||||
return fmt.Errorf("DNS proxy is nil")
|
return fmt.Errorf("DNS proxy is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
if tdev == nil {
|
|
||||||
return fmt.Errorf("TUN device is not available")
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
configurator, err = platform.NewWindowsDNSConfigurator(tdev)
|
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||||
}
|
}
|
||||||
@@ -6,13 +6,13 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"syscall"
|
"syscall"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.org/x/sys/windows/registry"
|
"golang.org/x/sys/windows/registry"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -33,13 +33,13 @@ type WindowsDNSConfigurator struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||||
// Accepts a TUN device and extracts the GUID internally
|
// Accepts an interface name and extracts the GUID internally
|
||||||
func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) {
|
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||||
if tunDevice == nil {
|
if interfaceName == "" {
|
||||||
return nil, fmt.Errorf("TUN device is required")
|
return nil, fmt.Errorf("interface name is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
guid, err := getInterfaceGUIDString(tunDevice)
|
guid, err := getInterfaceGUIDString(interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||||
}
|
}
|
||||||
@@ -268,24 +268,21 @@ func closeKey(closer io.Closer) {
|
|||||||
|
|
||||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||||
// This is required for registry-based DNS configuration on Windows
|
// This is required for registry-based DNS configuration on Windows
|
||||||
func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||||
if tunDevice == nil {
|
if interfaceName == "" {
|
||||||
return "", fmt.Errorf("TUN device is nil")
|
return "", fmt.Errorf("interface name is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
// The wireguard-go Windows TUN device has a LUID() method
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
// We need to use type assertion to access it
|
if err != nil {
|
||||||
type nativeTun interface {
|
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||||
LUID() uint64
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nativeDev, ok := tunDevice.(nativeTun)
|
luid, err := indexToLUID(uint32(iface.Index))
|
||||||
if !ok {
|
if err != nil {
|
||||||
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
|
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
luid := nativeDev.LUID()
|
|
||||||
|
|
||||||
// Convert LUID to GUID using Windows API
|
// Convert LUID to GUID using Windows API
|
||||||
guid, err := luidToGUID(luid)
|
guid, err := luidToGUID(luid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -295,6 +292,27 @@ func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
|||||||
return guid, nil
|
return guid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// indexToLUID converts a Windows interface index to a LUID
|
||||||
|
func indexToLUID(index uint32) (uint64, error) {
|
||||||
|
var luid uint64
|
||||||
|
|
||||||
|
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||||
|
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||||
|
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||||
|
|
||||||
|
// Call the Windows API
|
||||||
|
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||||
|
uintptr(index),
|
||||||
|
uintptr(unsafe.Pointer(&luid)),
|
||||||
|
)
|
||||||
|
|
||||||
|
if ret != 0 {
|
||||||
|
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return luid, nil
|
||||||
|
}
|
||||||
|
|
||||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||||
// using the Windows ConvertInterface* APIs
|
// using the Windows ConvertInterface* APIs
|
||||||
func luidToGUID(luid uint64) (string, error) {
|
func luidToGUID(luid uint64) (string, error) {
|
||||||
|
|||||||
22
olm/olm.go
22
olm/olm.go
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -17,7 +16,7 @@ import (
|
|||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
middleDevice "github.com/fosrl/olm/device"
|
middleDevice "github.com/fosrl/olm/device"
|
||||||
"github.com/fosrl/olm/dns"
|
"github.com/fosrl/olm/dns"
|
||||||
platform "github.com/fosrl/olm/dns/platform"
|
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||||
"github.com/fosrl/olm/network"
|
"github.com/fosrl/olm/network"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
@@ -93,7 +92,6 @@ var (
|
|||||||
globalCtx context.Context
|
globalCtx context.Context
|
||||||
stopRegister func()
|
stopRegister func()
|
||||||
stopPing chan struct{}
|
stopPing chan struct{}
|
||||||
configurator platform.DNSConfigurator
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func Init(ctx context.Context, config GlobalConfig) {
|
func Init(ctx context.Context, config GlobalConfig) {
|
||||||
@@ -577,7 +575,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
peerMonitor.Start()
|
peerMonitor.Start()
|
||||||
|
|
||||||
// Set up DNS override to use our DNS proxy
|
// Set up DNS override to use our DNS proxy
|
||||||
if err := SetupDNSOverride(interfaceName, dnsProxy); err != nil {
|
if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil {
|
||||||
logger.Error("Failed to setup DNS override: %v", err)
|
logger.Error("Failed to setup DNS override: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -1122,13 +1120,13 @@ func Close() {
|
|||||||
middleDev = nil
|
middleDev = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Restore original DNS
|
// // Restore original DNS
|
||||||
if configurator != nil {
|
// if configurator != nil {
|
||||||
fmt.Println("Restoring original DNS servers...")
|
// fmt.Println("Restoring original DNS servers...")
|
||||||
if err := configurator.RestoreDNS(); err != nil {
|
// if err := configurator.RestoreDNS(); err != nil {
|
||||||
log.Fatalf("Failed to restore DNS: %v", err)
|
// log.Fatalf("Failed to restore DNS: %v", err)
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Stop DNS proxy
|
// Stop DNS proxy
|
||||||
logger.Debug("Stopping DNS proxy")
|
logger.Debug("Stopping DNS proxy")
|
||||||
@@ -1177,7 +1175,7 @@ func StopTunnel() {
|
|||||||
Close()
|
Close()
|
||||||
|
|
||||||
// Restore original DNS configuration
|
// Restore original DNS configuration
|
||||||
if err := RestoreDNSOverride(); err != nil {
|
if err := dnsOverride.RestoreDNSOverride(); err != nil {
|
||||||
logger.Error("Failed to restore DNS: %v", err)
|
logger.Error("Failed to restore DNS: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user