From e238ee4d69b92f0c38d5519f8ceb8ee94345790d Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 24 Nov 2025 16:05:51 -0500 Subject: [PATCH] Convert windows working not using netsh route --- go.mod | 3 +- go.sum | 2 + olm/interface.go | 42 ---------- olm/interface_notwindows.go | 12 +++ olm/interface_windows.go | 60 +++++++++++++++ olm/route.go | 101 ------------------------ olm/route_notwindows.go | 11 +++ olm/route_windows.go | 148 ++++++++++++++++++++++++++++++++++++ 8 files changed, 235 insertions(+), 144 deletions(-) create mode 100644 olm/interface_notwindows.go create mode 100644 olm/interface_windows.go create mode 100644 olm/route_notwindows.go create mode 100644 olm/route_windows.go diff --git a/go.mod b/go.mod index 586f5e7..56b057c 100644 --- a/go.mod +++ b/go.mod @@ -5,18 +5,19 @@ go 1.25 require ( github.com/Microsoft/go-winio v0.6.2 github.com/fosrl/newt v0.0.0 + github.com/godbus/dbus/v5 v5.2.0 github.com/gorilla/websocket v1.5.3 github.com/miekg/dns v1.1.68 github.com/vishvananda/netlink v1.3.1 golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + golang.zx2c4.com/wireguard/windows v0.5.3 gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c software.sslmate.com/src/go-pkcs12 v0.6.0 ) require ( - github.com/godbus/dbus/v5 v5.2.0 // indirect github.com/google/btree v1.1.3 // indirect github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/crypto v0.44.0 // indirect diff --git a/go.sum b/go.sum index 275773c..addfffc 100644 --- a/go.sum +++ b/go.sum @@ -38,6 +38,8 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= +golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= +golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI= gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= diff --git a/olm/interface.go b/olm/interface.go index ae3f252..622382d 100644 --- a/olm/interface.go +++ b/olm/interface.go @@ -51,48 +51,6 @@ func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error { } } -func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring Windows interface: %s", interfaceName) - - // Calculate mask string (e.g., 255.255.255.0) - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - // Set the IP address using netsh - cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", interfaceName), - "source=static", - fmt.Sprintf("addr=%s", ip.String()), - fmt.Sprintf("mask=%s", maskIP.String())) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh command failed: %v, output: %s", err, out) - } - - // Bring up the interface if needed (in Windows, setting the IP usually brings it up) - // But we'll explicitly enable it to be sure - cmd = exec.Command("netsh", "interface", "set", "interface", - interfaceName, - "admin=enable") - - logger.Info("Running command: %v", cmd) - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) - } - - // Wait for the interface to be up and have the correct IP - err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) - if err != nil { - return fmt.Errorf("interface did not come up within timeout: %v", err) - } - - return nil -} - // waitForInterfaceUp polls the network interface until it's up or times out func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) diff --git a/olm/interface_notwindows.go b/olm/interface_notwindows.go new file mode 100644 index 0000000..75e8553 --- /dev/null +++ b/olm/interface_notwindows.go @@ -0,0 +1,12 @@ +//go:build !windows + +package olm + +import ( + "fmt" + "net" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + return fmt.Errorf("configureWindows called on non-Windows platform") +} diff --git a/olm/interface_windows.go b/olm/interface_windows.go new file mode 100644 index 0000000..6427723 --- /dev/null +++ b/olm/interface_windows.go @@ -0,0 +1,60 @@ +//go:build windows + +package olm + +import ( + "fmt" + "net" + "net/netip" + "time" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Get the LUID for the interface + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + + // Create the IP address prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ip.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ip) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert IP address") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Add the IP address to the interface + logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName) + err = luid.AddIPAddress(prefix) + if err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Wait for the interface to be up and have the correct IP + err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + if err != nil { + return fmt.Errorf("interface did not come up within timeout: %v", err) + } + + return nil +} diff --git a/olm/route.go b/olm/route.go index 14c18a1..e4e4006 100644 --- a/olm/route.go +++ b/olm/route.go @@ -5,7 +5,6 @@ import ( "net" "os/exec" "runtime" - "strconv" "strings" "github.com/fosrl/newt/logger" @@ -126,106 +125,6 @@ func LinuxRemoveRoute(destination string) error { return nil } -func WindowsAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "windows" { - return nil - } - - var cmd *exec.Cmd - - // Parse destination to get the IP and subnet - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "add", - ip.String(), - "mask", maskIP.String(), - gateway, - "metric", "1") - } else if interfaceName != "" { - // First, get the interface index - indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") - output, err := indexCmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) - } - - // Parse the output to find the interface index - lines := strings.Split(string(output), "\n") - var ifIndex string - for _, line := range lines { - if strings.Contains(line, interfaceName) { - fields := strings.Fields(line) - if len(fields) > 0 { - ifIndex = fields[0] - break - } - } - } - - if ifIndex == "" { - return fmt.Errorf("could not find index for interface %s", interfaceName) - } - - // Convert to integer to validate - idx, err := strconv.Atoi(ifIndex) - if err != nil { - return fmt.Errorf("invalid interface index: %v", err) - } - - // Route via interface using the index - cmd = exec.Command("route", "add", - ip.String(), - "mask", maskIP.String(), - "0.0.0.0", - "if", strconv.Itoa(idx)) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func WindowsRemoveRoute(destination string) error { - // Parse destination to get the IP - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - cmd := exec.Command("route", "delete", - ip.String(), - "mask", maskIP.String()) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil -} - // addRouteForServerIP adds an OS-specific route for the server IP func addRouteForServerIP(serverIP, interfaceName string) error { if err := addRouteForNetworkConfig(serverIP); err != nil { diff --git a/olm/route_notwindows.go b/olm/route_notwindows.go new file mode 100644 index 0000000..910ed26 --- /dev/null +++ b/olm/route_notwindows.go @@ -0,0 +1,11 @@ +//go:build !windows + +package olm + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + return nil +} + +func WindowsRemoveRoute(destination string) error { + return nil +} diff --git a/olm/route_windows.go b/olm/route_windows.go new file mode 100644 index 0000000..c478a04 --- /dev/null +++ b/olm/route_windows.go @@ -0,0 +1,148 @@ +//go:build windows + +package olm + +import ( + "fmt" + "net" + "net/netip" + "runtime" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "windows" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + var luid winipcfg.LUID + var nextHop netip.Addr + + if interfaceName != "" { + // Get the interface LUID - needed for both gateway and interface-only routes + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + } + + if gateway != "" { + // Route with specific gateway + gwIP := net.ParseIP(gateway) + if gwIP == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + // Convert to correct IP version + if ip4 := gwIP.To4(); ip4 != nil { + nextHop, _ = netip.AddrFromSlice(ip4) + } else { + nextHop, _ = netip.AddrFromSlice(gwIP) + } + if !nextHop.IsValid() { + return fmt.Errorf("failed to convert gateway IP") + } + logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName) + } else if interfaceName != "" { + // Route via interface only + if addr.Is4() { + nextHop = netip.IPv4Unspecified() + } else { + nextHop = netip.IPv6Unspecified() + } + logger.Info("Adding route to %s via interface %s", destination, interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + // Add the route using winipcfg + err = luid.AddRoute(prefix, nextHop, 1) + if err != nil { + return fmt.Errorf("failed to add route: %v", err) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Get all routes and find the one to delete + // We need to get the LUID from the existing route + var family winipcfg.AddressFamily + if addr.Is4() { + family = 2 // AF_INET + } else { + family = 23 // AF_INET6 + } + + routes, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return fmt.Errorf("failed to get route table: %v", err) + } + + // Find and delete matching route + for _, route := range routes { + routePrefix := route.DestinationPrefix.Prefix() + if routePrefix == prefix { + logger.Info("Removing route to %s", destination) + err = route.Delete() + if err != nil { + return fmt.Errorf("failed to delete route: %v", err) + } + return nil + } + } + + return fmt.Errorf("route to %s not found", destination) +}