mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Convert windows working not using netsh route
Former-commit-id: e238ee4d69
This commit is contained in:
3
go.mod
3
go.mod
@@ -5,18 +5,19 @@ go 1.25
|
||||
require (
|
||||
github.com/Microsoft/go-winio v0.6.2
|
||||
github.com/fosrl/newt v0.0.0
|
||||
github.com/godbus/dbus/v5 v5.2.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/miekg/dns v1.1.68
|
||||
github.com/vishvananda/netlink v1.3.1
|
||||
golang.org/x/sys v0.38.0
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/godbus/dbus/v5 v5.2.0 // indirect
|
||||
github.com/google/btree v1.1.3 // indirect
|
||||
github.com/vishvananda/netns v0.0.5 // indirect
|
||||
golang.org/x/crypto v0.44.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -38,6 +38,8 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z
|
||||
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
|
||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c h1:m/r7OM+Y2Ty1sgBQ7Qb27VgIMBW8ZZhT4gLnUyDIhzI=
|
||||
gvisor.dev/gvisor v0.0.0-20250503011706-39ed1f5ac29c/go.mod h1:3r5CMtNQMKIvBlrmM9xWUNamjKBYPOWyXOjmg5Kts3g=
|
||||
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
|
||||
|
||||
@@ -51,48 +51,6 @@ func ConfigureInterface(interfaceName string, wgData WgData, mtu int) error {
|
||||
}
|
||||
}
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||
|
||||
// Calculate mask string (e.g., 255.255.255.0)
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
// Set the IP address using netsh
|
||||
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
||||
fmt.Sprintf("name=%s", interfaceName),
|
||||
"source=static",
|
||||
fmt.Sprintf("addr=%s", ip.String()),
|
||||
fmt.Sprintf("mask=%s", maskIP.String()))
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
||||
// But we'll explicitly enable it to be sure
|
||||
cmd = exec.Command("netsh", "interface", "set", "interface",
|
||||
interfaceName,
|
||||
"admin=enable")
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err = cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
// Wait for the interface to be up and have the correct IP
|
||||
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||
|
||||
12
olm/interface_notwindows.go
Normal file
12
olm/interface_notwindows.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build !windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
return fmt.Errorf("configureWindows called on non-Windows platform")
|
||||
}
|
||||
60
olm/interface_windows.go
Normal file
60
olm/interface_windows.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||
|
||||
// Get the LUID for the interface
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
// Create the IP address prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ip.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ip)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert IP address")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Add the IP address to the interface
|
||||
logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName)
|
||||
err = luid.AddIPAddress(prefix)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add IP address: %v", err)
|
||||
}
|
||||
|
||||
// Wait for the interface to be up and have the correct IP
|
||||
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
101
olm/route.go
101
olm/route.go
@@ -5,7 +5,6 @@ import (
|
||||
"net"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
@@ -126,106 +125,6 @@ func LinuxRemoveRoute(destination string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
|
||||
// Parse destination to get the IP and subnet
|
||||
ip, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Calculate the subnet mask
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
cmd = exec.Command("route", "add",
|
||||
ip.String(),
|
||||
"mask", maskIP.String(),
|
||||
gateway,
|
||||
"metric", "1")
|
||||
} else if interfaceName != "" {
|
||||
// First, get the interface index
|
||||
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
||||
output, err := indexCmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
||||
}
|
||||
|
||||
// Parse the output to find the interface index
|
||||
lines := strings.Split(string(output), "\n")
|
||||
var ifIndex string
|
||||
for _, line := range lines {
|
||||
if strings.Contains(line, interfaceName) {
|
||||
fields := strings.Fields(line)
|
||||
if len(fields) > 0 {
|
||||
ifIndex = fields[0]
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if ifIndex == "" {
|
||||
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
||||
}
|
||||
|
||||
// Convert to integer to validate
|
||||
idx, err := strconv.Atoi(ifIndex)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid interface index: %v", err)
|
||||
}
|
||||
|
||||
// Route via interface using the index
|
||||
cmd = exec.Command("route", "add",
|
||||
ip.String(),
|
||||
"mask", maskIP.String(),
|
||||
"0.0.0.0",
|
||||
"if", strconv.Itoa(idx))
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
// Parse destination to get the IP
|
||||
ip, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Calculate the subnet mask
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
mask := net.CIDRMask(maskBits, 32)
|
||||
maskIP := net.IP(mask)
|
||||
|
||||
cmd := exec.Command("route", "delete",
|
||||
ip.String(),
|
||||
"mask", maskIP.String())
|
||||
|
||||
logger.Info("Running command: %v", cmd)
|
||||
out, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||
if err := addRouteForNetworkConfig(serverIP); err != nil {
|
||||
|
||||
11
olm/route_notwindows.go
Normal file
11
olm/route_notwindows.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !windows
|
||||
|
||||
package olm
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
return nil
|
||||
}
|
||||
148
olm/route_windows.go
Normal file
148
olm/route_windows.go
Normal file
@@ -0,0 +1,148 @@
|
||||
//go:build windows
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||
)
|
||||
|
||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||
if runtime.GOOS != "windows" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
var luid winipcfg.LUID
|
||||
var nextHop netip.Addr
|
||||
|
||||
if interfaceName != "" {
|
||||
// Get the interface LUID - needed for both gateway and interface-only routes
|
||||
iface, err := net.InterfaceByName(interfaceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||
}
|
||||
|
||||
luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if gateway != "" {
|
||||
// Route with specific gateway
|
||||
gwIP := net.ParseIP(gateway)
|
||||
if gwIP == nil {
|
||||
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||
}
|
||||
// Convert to correct IP version
|
||||
if ip4 := gwIP.To4(); ip4 != nil {
|
||||
nextHop, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
nextHop, _ = netip.AddrFromSlice(gwIP)
|
||||
}
|
||||
if !nextHop.IsValid() {
|
||||
return fmt.Errorf("failed to convert gateway IP")
|
||||
}
|
||||
logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName)
|
||||
} else if interfaceName != "" {
|
||||
// Route via interface only
|
||||
if addr.Is4() {
|
||||
nextHop = netip.IPv4Unspecified()
|
||||
} else {
|
||||
nextHop = netip.IPv6Unspecified()
|
||||
}
|
||||
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||
} else {
|
||||
return fmt.Errorf("either gateway or interface must be specified")
|
||||
}
|
||||
|
||||
// Add the route using winipcfg
|
||||
err = luid.AddRoute(prefix, nextHop, 1)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add route: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func WindowsRemoveRoute(destination string) error {
|
||||
// Parse destination CIDR
|
||||
_, ipNet, err := net.ParseCIDR(destination)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid destination address: %v", err)
|
||||
}
|
||||
|
||||
// Convert to netip.Prefix
|
||||
maskBits, _ := ipNet.Mask.Size()
|
||||
|
||||
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||
var addr netip.Addr
|
||||
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||
// IPv4 address
|
||||
addr, _ = netip.AddrFromSlice(ip4)
|
||||
} else {
|
||||
// IPv6 address
|
||||
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||
}
|
||||
if !addr.IsValid() {
|
||||
return fmt.Errorf("failed to convert destination IP")
|
||||
}
|
||||
prefix := netip.PrefixFrom(addr, maskBits)
|
||||
|
||||
// Get all routes and find the one to delete
|
||||
// We need to get the LUID from the existing route
|
||||
var family winipcfg.AddressFamily
|
||||
if addr.Is4() {
|
||||
family = 2 // AF_INET
|
||||
} else {
|
||||
family = 23 // AF_INET6
|
||||
}
|
||||
|
||||
routes, err := winipcfg.GetIPForwardTable2(family)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get route table: %v", err)
|
||||
}
|
||||
|
||||
// Find and delete matching route
|
||||
for _, route := range routes {
|
||||
routePrefix := route.DestinationPrefix.Prefix()
|
||||
if routePrefix == prefix {
|
||||
logger.Info("Removing route to %s", destination)
|
||||
err = route.Delete()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete route: %v", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("route to %s not found", destination)
|
||||
}
|
||||
Reference in New Issue
Block a user