mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Refactor
This commit is contained in:
@@ -11,6 +11,8 @@ import (
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on macOS
|
||||
// Uses scutil for DNS configuration
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Linux/FreeBSD
|
||||
// Tries systemd-resolved, NetworkManager, resolvconf, or falls back to /etc/resolv.conf
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
)
|
||||
|
||||
var configurator platform.DNSConfigurator
|
||||
|
||||
// SetupDNSOverride configures the system DNS to use the DNS proxy on Windows
|
||||
// Uses registry-based configuration (automatically extracts interface GUID)
|
||||
func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
@@ -18,12 +20,8 @@ func SetupDNSOverride(interfaceName string, dnsProxy *dns.DNSProxy) error {
|
||||
return fmt.Errorf("DNS proxy is nil")
|
||||
}
|
||||
|
||||
if tdev == nil {
|
||||
return fmt.Errorf("TUN device is not available")
|
||||
}
|
||||
|
||||
var err error
|
||||
configurator, err = platform.NewWindowsDNSConfigurator(tdev)
|
||||
configurator, err = platform.NewWindowsDNSConfigurator(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create Windows DNS configurator: %w", err)
|
||||
}
|
||||
@@ -6,13 +6,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -33,13 +33,13 @@ type WindowsDNSConfigurator struct {
|
||||
}
|
||||
|
||||
// NewWindowsDNSConfigurator creates a new Windows DNS configurator
|
||||
// Accepts a TUN device and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(tunDevice tun.Device) (*WindowsDNSConfigurator, error) {
|
||||
if tunDevice == nil {
|
||||
return nil, fmt.Errorf("TUN device is required")
|
||||
// Accepts an interface name and extracts the GUID internally
|
||||
func NewWindowsDNSConfigurator(interfaceName string) (*WindowsDNSConfigurator, error) {
|
||||
if interfaceName == "" {
|
||||
return nil, fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
guid, err := getInterfaceGUIDString(tunDevice)
|
||||
guid, err := getInterfaceGUIDString(interfaceName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get interface GUID: %w", err)
|
||||
}
|
||||
@@ -268,24 +268,21 @@ func closeKey(closer io.Closer) {
|
||||
|
||||
// getInterfaceGUIDString retrieves the GUID string for a Windows TUN interface
|
||||
// This is required for registry-based DNS configuration on Windows
|
||||
func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
||||
if tunDevice == nil {
|
||||
return "", fmt.Errorf("TUN device is nil")
|
||||
func getInterfaceGUIDString(interfaceName string) (string, error) {
|
||||
if interfaceName == "" {
|
||||
return "", fmt.Errorf("interface name is required")
|
||||
}
|
||||
|
||||
// The wireguard-go Windows TUN device has a LUID() method
|
||||
// We need to use type assertion to access it
|
||||
type nativeTun interface {
|
||||
LUID() uint64
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get interface %s: %w", interfaceName, err)
|
||||
}
|
||||
|
||||
nativeDev, ok := tunDevice.(nativeTun)
|
||||
if !ok {
|
||||
return "", fmt.Errorf("TUN device does not support LUID retrieval (not a native Windows TUN device)")
|
||||
luid, err := indexToLUID(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to convert index to LUID: %w", err)
|
||||
}
|
||||
|
||||
luid := nativeDev.LUID()
|
||||
|
||||
// Convert LUID to GUID using Windows API
|
||||
guid, err := luidToGUID(luid)
|
||||
if err != nil {
|
||||
@@ -295,6 +292,27 @@ func getInterfaceGUIDString(tunDevice tun.Device) (string, error) {
|
||||
return guid, nil
|
||||
}
|
||||
|
||||
// indexToLUID converts a Windows interface index to a LUID
|
||||
func indexToLUID(index uint32) (uint64, error) {
|
||||
var luid uint64
|
||||
|
||||
// Load the iphlpapi.dll and get the ConvertInterfaceIndexToLuid function
|
||||
iphlpapi := windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
convertInterfaceIndexToLuid := iphlpapi.NewProc("ConvertInterfaceIndexToLuid")
|
||||
|
||||
// Call the Windows API
|
||||
ret, _, err := convertInterfaceIndexToLuid.Call(
|
||||
uintptr(index),
|
||||
uintptr(unsafe.Pointer(&luid)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
return 0, fmt.Errorf("ConvertInterfaceIndexToLuid failed with code %d: %w", ret, err)
|
||||
}
|
||||
|
||||
return luid, nil
|
||||
}
|
||||
|
||||
// luidToGUID converts a Windows LUID (Locally Unique Identifier) to a GUID string
|
||||
// using the Windows ConvertInterface* APIs
|
||||
func luidToGUID(luid uint64) (string, error) {
|
||||
|
||||
22
olm/olm.go
22
olm/olm.go
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
@@ -17,7 +16,7 @@ import (
|
||||
"github.com/fosrl/olm/api"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
platform "github.com/fosrl/olm/dns/platform"
|
||||
dnsOverride "github.com/fosrl/olm/dns/override"
|
||||
"github.com/fosrl/olm/network"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
@@ -93,7 +92,6 @@ var (
|
||||
globalCtx context.Context
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
configurator platform.DNSConfigurator
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, config GlobalConfig) {
|
||||
@@ -577,7 +575,7 @@ func StartTunnel(config TunnelConfig) {
|
||||
peerMonitor.Start()
|
||||
|
||||
// Set up DNS override to use our DNS proxy
|
||||
if err := SetupDNSOverride(interfaceName, dnsProxy); err != nil {
|
||||
if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil {
|
||||
logger.Error("Failed to setup DNS override: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -1122,13 +1120,13 @@ func Close() {
|
||||
middleDev = nil
|
||||
}
|
||||
|
||||
// Restore original DNS
|
||||
if configurator != nil {
|
||||
fmt.Println("Restoring original DNS servers...")
|
||||
if err := configurator.RestoreDNS(); err != nil {
|
||||
log.Fatalf("Failed to restore DNS: %v", err)
|
||||
}
|
||||
}
|
||||
// // Restore original DNS
|
||||
// if configurator != nil {
|
||||
// fmt.Println("Restoring original DNS servers...")
|
||||
// if err := configurator.RestoreDNS(); err != nil {
|
||||
// log.Fatalf("Failed to restore DNS: %v", err)
|
||||
// }
|
||||
// }
|
||||
|
||||
// Stop DNS proxy
|
||||
logger.Debug("Stopping DNS proxy")
|
||||
@@ -1177,7 +1175,7 @@ func StopTunnel() {
|
||||
Close()
|
||||
|
||||
// Restore original DNS configuration
|
||||
if err := RestoreDNSOverride(); err != nil {
|
||||
if err := dnsOverride.RestoreDNSOverride(); err != nil {
|
||||
logger.Error("Failed to restore DNS: %v", err)
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user