From 13e7f55b3059184b97e314909c1c253cfdd9b34b Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 3 May 2025 16:41:13 -0400 Subject: [PATCH] Interface comes up --- common.go | 279 ++++++++++++++++++++++++++++++++++ main.go | 25 +--- main_windows.go | 389 ------------------------------------------------ unix.go | 156 +++---------------- windows.go | 25 ++++ 5 files changed, 325 insertions(+), 549 deletions(-) delete mode 100644 main_windows.go create mode 100644 windows.go diff --git a/common.go b/common.go index bc00faa..728152f 100644 --- a/common.go +++ b/common.go @@ -6,12 +6,17 @@ import ( "encoding/json" "fmt" "net" + "os/exec" + "regexp" + "runtime" + "strconv" "strings" "time" "github.com/fosrl/newt/logger" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" + "github.com/vishvananda/netlink" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" "golang.org/x/exp/rand" @@ -535,3 +540,277 @@ func RemovePeer(dev *device.Device, siteId int, publicKey string) error { return nil } + +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, wgData WgData) error { + var ipAddr string = wgData.TunnelIP + + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(ipAddr) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + case "windows": + return configureWindows(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Calculate mask string (e.g., 255.255.255.0) + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + // Set the IP address using netsh + cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", + fmt.Sprintf("name=%s", interfaceName), + "source=static", + fmt.Sprintf("addr=%s", ip.String()), + fmt.Sprintf("mask=%s", maskIP.String())) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh command failed: %v, output: %s", err, out) + } + + // Bring up the interface if needed (in Windows, setting the IP usually brings it up) + // But we'll explicitly enable it to be sure + cmd = exec.Command("netsh", "interface", "set", "interface", + fmt.Sprintf("%s", interfaceName), + "admin=enable") + + logger.Info("Running command: %v", cmd) + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) + } + + return nil +} + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + var cmd *exec.Cmd + + // Parse destination to get the IP and subnet + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + gateway, + "metric", "1") + } else if interfaceName != "" { + // First, get the interface index + indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") + output, err := indexCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) + } + + // Parse the output to find the interface index + lines := strings.Split(string(output), "\n") + var ifIndex string + for _, line := range lines { + if strings.Contains(line, interfaceName) { + fields := strings.Fields(line) + if len(fields) > 0 { + ifIndex = fields[0] + break + } + } + } + + if ifIndex == "" { + return fmt.Errorf("could not find index for interface %s", interfaceName) + } + + // Convert to integer to validate + idx, err := strconv.Atoi(ifIndex) + if err != nil { + return fmt.Errorf("invalid interface index: %v", err) + } + + // Route via interface using the index + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + "0.0.0.0", + "if", strconv.Itoa(idx)) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination to get the IP + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + cmd := exec.Command("route", "delete", + ip.String(), + "mask", maskIP.String()) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func findUnusedUTUN() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to list interfaces: %v", err) + } + used := make(map[int]bool) + re := regexp.MustCompile(`^utun(\d+)$`) + for _, iface := range ifaces { + if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { + if num, err := strconv.Atoi(matches[1]); err == nil { + used[num] = true + } + } + } + // Try utun0 up to utun255. + for i := 0; i < 256; i++ { + if !used[i] { + return fmt.Sprintf("utun%d", i), nil + } + } + return "", fmt.Errorf("no unused utun interface found") +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring darwin interface: %s", interfaceName) + + prefix, _ := ipNet.Mask.Size() + ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) + } + + // Bring up the interface + cmd = exec.Command("ifconfig", interfaceName, "up") + logger.Info("Running command: %v", cmd) + + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) + } + + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} + +func DarwinAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "darwin" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func DarwinRemoveRoute(destination string) error { + if runtime.GOOS != "darwin" { + return nil + } + + cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} diff --git a/main.go b/main.go index 0d0ee55..3559cb2 100644 --- a/main.go +++ b/main.go @@ -1,5 +1,3 @@ -//go:build !windows - package main import ( @@ -18,9 +16,7 @@ import ( "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" - "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -205,7 +201,6 @@ func main() { return } - // NEED TO DETERMINE AVAILABLE TUN DEVICE HERE tdev, err = func() (tun.Device, error) { tunFdStr := os.Getenv(ENV_WG_TUN_FD) @@ -222,20 +217,7 @@ func main() { return tun.CreateTUN(interfaceName, mtuInt) } - // construct tun device from supplied fd - - fd, err := strconv.ParseUint(tunFdStr, 10, 32) - if err != nil { - return nil, err - } - - err = unix.SetNonblock(int(fd), true) - if err != nil { - return nil, err - } - - file := os.NewFile(uintptr(fd), "") - return tun.CreateTUNFromFile(file, mtuInt) + return createTUNFromFD(tunFdStr, mtuInt) }() if err != nil { @@ -249,11 +231,10 @@ func main() { } // open UAPI file (or use supplied fd) - fileUAPI, err := func() (*os.File, error) { uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) if uapiFdStr == "" { - return ipc.UAPIOpen(interfaceName) + return uapiOpen(interfaceName) } // use supplied fd @@ -278,7 +259,7 @@ func main() { errs := make(chan error) - uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) + uapi, err := uapiListen(interfaceName, fileUAPI) if err != nil { logger.Error("Failed to listen on uapi socket: %v", err) os.Exit(1) diff --git a/main_windows.go b/main_windows.go deleted file mode 100644 index c4134e8..0000000 --- a/main_windows.go +++ /dev/null @@ -1,389 +0,0 @@ -//go:build windows - -package main - -import ( - "encoding/json" - "flag" - "fmt" - "net" - "os" - "os/exec" - "os/signal" - "runtime" - "strconv" - "syscall" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/olm/websocket" - "golang.org/x/sys/windows" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/ipc" - "golang.zx2c4.com/wireguard/tun" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" -) - -// ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData) error { - var ipAddr string = wgData.TunnelIP - var destIP string = wgData.ServerIP - - if runtime.GOOS == "windows" { - return configureWindows(interfaceName, ipAddr, destIP) - } - - return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) -} - -func configureWindows(interfaceName string, ipAddr, destIP string) error { - logger.Info("Configuring Windows interface: %s", interfaceName) - - // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(ipAddr) - if err != nil { - return fmt.Errorf("invalid IP address: %v", err) - } - - // Set the IP address using netsh - // Windows uses the 'netsh' command to configure network interfaces - maskBits, _ := ipNet.Mask.Size() - - // create a mask string like 255.255.255.0 from the maskBits - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", interfaceName), - "source=static", - fmt.Sprintf("addr=%s", ip.String()), - fmt.Sprintf("mask=%s", maskIP.String())) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh command failed: %v, output: %s", err, out) - } - - // Add a route to the destination IP - cmd = exec.Command("netsh", "interface", "ipv4", "add", "route", - fmt.Sprintf("%s/32", destIP), - fmt.Sprintf("interface=%s", interfaceName), - "metric=1") - - logger.Info("Running command: %v", cmd) - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh route command failed: %v, output: %s", err, out) - } - - return nil -} - -func main() { - var ( - endpoint string - id string - secret string - mtu string - mtuInt int - dns string - privateKey wgtypes.Key - err error - logLevel string - interfaceName string - ) - - stopHolepunch = make(chan struct{}) - stopRegister = make(chan struct{}) - - // Check OS - if runtime.GOOS != "windows" { - fmt.Println("This version of olm is only for Windows systems") - os.Exit(1) - } - - // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values - endpoint = os.Getenv("PANGOLIN_ENDPOINT") - id = os.Getenv("OLM_ID") - secret = os.Getenv("OLM_SECRET") - mtu = os.Getenv("MTU") - dns = os.Getenv("DNS") - logLevel = os.Getenv("LOG_LEVEL") - interfaceName = os.Getenv("INTERFACE") - - if endpoint == "" { - flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") - } - if id == "" { - flag.StringVar(&id, "id", "", "Olm ID") - } - if secret == "" { - flag.StringVar(&secret, "secret", "", "Olm secret") - } - if mtu == "" { - flag.StringVar(&mtu, "mtu", "1280", "MTU to use") - } - if dns == "" { - flag.StringVar(&dns, "dns", "8.8.8.8", "DNS server to use") - } - if logLevel == "" { - flag.StringVar(&logLevel, "log-level", "INFO", "Log level (DEBUG, INFO, WARN, ERROR, FATAL)") - } - if interfaceName == "" { - flag.StringVar(&interfaceName, "interface", "olm", "Name of the WireGuard interface") - } - - // do a --version check - version := flag.Bool("version", false, "Print the version") - flag.Parse() - - if *version { - fmt.Println("Olm Windows version replaceme") - os.Exit(0) - } - - logger.Init() - loggerLevel := parseLogLevel(logLevel) - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) - - // parse the mtu string into an int - mtuInt, err = strconv.Atoi(mtu) - if err != nil { - logger.Fatal("Failed to parse MTU: %v", err) - } - - privateKey, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Fatal("Failed to generate private key: %v", err) - } - - // Create a new olm - olm, err := websocket.NewClient( - id, // CLI arg takes precedence - secret, // CLI arg takes precedence - endpoint, - ) - if err != nil { - logger.Fatal("Failed to create olm: %v", err) - } - - sourcePort, err := FindAvailableUDPPort(49152, 65535) - if err != nil { - fmt.Printf("Error finding available port: %v\n", err) - os.Exit(1) - } - - // Create TUN device and network stack - var dev *device.Device - var wgData WgData - var holePunchData HolePunchData - var uapi net.Listener - var tdev tun.Device - - olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - olm.Close() - }) - - olm.RegisterHandler("olm/wg/update", func(msg websocket.WSMessage) { - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - endpoint, err := resolveDomain(wgData.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s - public_key=%s - allowed_ip=%s/32 - endpoint=%s - persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) - } - }) - - // Register handlers for different message types - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { - logger.Info("Received message: %v", msg.Data) - close(stopRegister) - // if there is an existing tunnel then close it - if dev != nil { - logger.Info("Got new message. Closing existing tunnel!") - dev.Close() - } - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - if err := json.Unmarshal(jsonData, &wgData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Windows-specific TUN device creation - tdev, err = tun.CreateTUN(interfaceName, mtuInt) - if err != nil { - logger.Error("Failed to create TUN device: %v", err) - return - } - - realInterfaceName, err2 := tdev.Name() - if err2 == nil { - interfaceName = realInterfaceName - } - - // Create the WireGuard device - dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger( - mapToWireGuardLogLevel(loggerLevel), - "wireguard: ", - )) - - // Setup UAPI for Windows - uapi, err = ipc.UAPIListen(interfaceName) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } - - errs := make(chan error) - go func() { - for { - conn, err := uapi.Accept() - if err != nil { - errs <- err - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") - - host, err := resolveDomain(wgData.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } - - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s -persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) - } - - // Bring up the device - err = dev.Up() - if err != nil { - logger.Error("Failed to bring up WireGuard device: %v", err) - } - - // Configure the interface - err = ConfigureInterface(realInterfaceName, wgData) - if err != nil { - logger.Error("Failed to configure interface: %v", err) - } - - close(stopHolepunch) - // Monitor the connection for activity - monitorConnection(dev, func() { - host, err := resolveDomain(endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s:21820 -persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) - } - logger.Info("Adjusted to point to relay!") - }) - logger.Info("WireGuard device created.") - }) - - olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - logger.Info("Received message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &holePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - gerbilServerPubKey = holePunchData.ServerPubKey - }) - - olm.OnConnect(func() error { - publicKey := privateKey.PublicKey() - logger.Debug("Public key: %s", publicKey) - - go keepSendingRegistration(olm, publicKey.String()) - - logger.Info("Sent registration message") - return nil - }) - - olm.OnTokenUpdate(func(token string) { - olmToken = token - }) - - // Connect to the WebSocket server - if err := olm.Connect(); err != nil { - logger.Fatal("Failed to connect to server: %v", err) - } - defer olm.Close() - - go keepSendingUDPHolePunch(endpoint, id, sourcePort) - - // Wait for interrupt signal - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, windows.SIGTERM) - <-sigCh - - select { - case <-stopHolepunch: - // Channel already closed, do nothing - default: - close(stopHolepunch) - } - - select { - case <-stopRegister: - // Channel already closed - default: - close(stopRegister) - } - - if uapi != nil { - uapi.Close() - } - - if dev != nil { - dev.Close() - } -} diff --git a/unix.go b/unix.go index dad06c0..3a9c09e 100644 --- a/unix.go +++ b/unix.go @@ -3,153 +3,33 @@ package main import ( - "fmt" "net" - "os/exec" - "regexp" - "runtime" + "os" "strconv" - "github.com/fosrl/newt/logger" - "github.com/vishvananda/netlink" + "golang.org/x/sys/unix" + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" ) -// ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData) error { - var ipAddr string = wgData.TunnelIP - - // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(ipAddr) +func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { + fd, err := strconv.ParseUint(tunFdStr, 10, 32) if err != nil { - return fmt.Errorf("invalid IP address: %v", err) + return nil, err } - switch runtime.GOOS { - case "linux": - return configureLinux(interfaceName, ip, ipNet) - case "darwin": - return configureDarwin(interfaceName, ip, ipNet) - default: - return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + err = unix.SetNonblock(int(fd), true) + if err != nil { + return nil, err } + + file := os.NewFile(uintptr(fd), "") + return tun.CreateTUNFromFile(file, mtuInt) +} +func uapiOpen(interfaceName string) (*os.File, error) { + return ipc.UAPIOpen(interfaceName) } -func findUnusedUTUN() (string, error) { - ifaces, err := net.Interfaces() - if err != nil { - return "", fmt.Errorf("failed to list interfaces: %v", err) - } - used := make(map[int]bool) - re := regexp.MustCompile(`^utun(\d+)$`) - for _, iface := range ifaces { - if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { - if num, err := strconv.Atoi(matches[1]); err == nil { - used[num] = true - } - } - } - // Try utun0 up to utun255. - for i := 0; i < 256; i++ { - if !used[i] { - return fmt.Sprintf("utun%d", i), nil - } - } - return "", fmt.Errorf("no unused utun interface found") -} - -func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring darwin interface: %s", interfaceName) - - prefix, _ := ipNet.Mask.Size() - ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) - - cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) - } - - // Bring up the interface - cmd = exec.Command("ifconfig", interfaceName, "up") - logger.Info("Running command: %v", cmd) - - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) - } - - return nil -} - -func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - // Get the interface - link, err := netlink.LinkByName(interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) - } - - // Create the IP address attributes - addr := &netlink.Addr{ - IPNet: &net.IPNet{ - IP: ip, - Mask: ipNet.Mask, - }, - } - - // Add the IP address to the interface - if err := netlink.AddrAdd(link, addr); err != nil { - return fmt.Errorf("failed to add IP address: %v", err) - } - - // Bring up the interface - if err := netlink.LinkSetUp(link); err != nil { - return fmt.Errorf("failed to bring up interface: %v", err) - } - - return nil -} - -func DarwinAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "darwin" { - return nil - } - - var cmd *exec.Cmd - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) - } else if interfaceName != "" { - // Route via interface - cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func DarwinRemoveRoute(destination string) error { - if runtime.GOOS != "darwin" { - return nil - } - - cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil +func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + return ipc.UAPIListen(interfaceName, fileUAPI) } diff --git a/windows.go b/windows.go new file mode 100644 index 0000000..032096b --- /dev/null +++ b/windows.go @@ -0,0 +1,25 @@ +//go:build windows + +package main + +import ( + "errors" + "net" + "os" + + "golang.zx2c4.com/wireguard/ipc" + "golang.zx2c4.com/wireguard/tun" +) + +func createTUNFromFD(tunFdStr string, mtuInt int) (tun.Device, error) { + return nil, errors.New("CreateTUNFromFile not supported on Windows") +} + +func uapiOpen(interfaceName string) (*os.File, error) { + return nil, nil +} + +func uapiListen(interfaceName string, fileUAPI *os.File) (net.Listener, error) { + // On Windows, UAPIListen only takes one parameter + return ipc.UAPIListen(interfaceName) +}