diff --git a/go.mod b/go.mod index da5328e..0e8d553 100644 --- a/go.mod +++ b/go.mod @@ -7,13 +7,15 @@ toolchain go1.23.2 require ( github.com/fosrl/newt v0.0.0-20250215225251-76503f3f2cd8 golang.org/x/net v0.33.0 - golang.org/x/sys v0.28.0 + golang.org/x/sys v0.30.0 golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 ) require ( github.com/gorilla/websocket v1.5.3 // indirect + github.com/vishvananda/netlink v1.3.0 // indirect + github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/crypto v0.31.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect ) diff --git a/go.sum b/go.sum index 227184b..4997342 100644 --- a/go.sum +++ b/go.sum @@ -6,12 +6,20 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/time v0.7.0 h1:ntUhktv3OPE6TgYxXWv9vKvUSJyIFJlyohwbkEwPrKQ= golang.org/x/time v0.7.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= diff --git a/main.go b/main.go index 299ef05..0a08413 100644 --- a/main.go +++ b/main.go @@ -9,16 +9,20 @@ import ( "net" "os" "os/signal" + "runtime" "strconv" "strings" "syscall" + "unsafe" "github.com/fosrl/newt/logger" "github.com/fosrl/olm/websocket" + "github.com/vishvananda/netlink" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -191,6 +195,128 @@ func resolveDomain(domain string) (string, error) { return ipAddr, nil } +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, ipAddr string) error { + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(ipAddr) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get interface by name + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // print something using the iface + logger.Info("Interface %s: %v", interfaceName, iface) + + // Create socket + fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0) + if err != nil { + return fmt.Errorf("failed to create socket: %v", err) + } + defer syscall.Close(fd) + + // Prepare interface request structure + ifr := struct { + Name [16]byte + Flags uint16 + }{} + copy(ifr.Name[:], interfaceName) + + // Get current flags + if err := ioctl(fd, syscall.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to get interface flags: %v", err) + } + + // Set interface up + ifr.Flags |= syscall.IFF_UP | syscall.IFF_RUNNING + if err := ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifr))); err != nil { + return fmt.Errorf("failed to set interface up: %v", err) + } + + // Prepare address structure + var addr syscall.SockaddrInet4 + copy(addr.Addr[:], ip.To4()) + + // Create interface address request + ifra := struct { + Name [16]byte + Addr syscall.RawSockaddrInet4 + Mask syscall.RawSockaddrInet4 + }{} + copy(ifra.Name[:], interfaceName) + copy(ifra.Addr.Addr[:], ip.To4()) + copy(ifra.Mask.Addr[:], ipNet.Mask) + + // Set IP address + if err := ioctl(fd, syscall.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil { + return fmt.Errorf("failed to set interface address: %v", err) + } + + // Set netmask + if err := ioctl(fd, syscall.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil { + return fmt.Errorf("failed to set interface netmask: %v", err) + } + + return nil +} + +// Helper function for ioctl calls +func ioctl(fd int, request uint, argp uintptr) error { + _, _, errno := syscall.Syscall( + syscall.SYS_IOCTL, + uintptr(fd), + uintptr(request), + argp, + ) + if errno != 0 { + return os.NewSyscallError("ioctl", errno) + } + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} + func main() { var ( endpoint string @@ -285,6 +411,7 @@ func main() { var dev *device.Device // var connected bool var wgData WgData + var uapi *os.File olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") @@ -309,6 +436,8 @@ func main() { // return // } + logger.Info("Received message: %v", msg.Data) + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Info("Error marshaling data: %v", err) @@ -347,12 +476,61 @@ func main() { return } + realInterfaceName, err2 := tdev.Name() + if err2 == nil { + interfaceName = realInterfaceName + } + + // open UAPI file (or use supplied fd) + + fileUAPI, err := func() (*os.File, error) { + uapiFdStr := os.Getenv(ENV_WG_UAPI_FD) + if uapiFdStr == "" { + return ipc.UAPIOpen(interfaceName) + } + + // use supplied fd + + fd, err := strconv.ParseUint(uapiFdStr, 10, 32) + if err != nil { + return nil, err + } + + return os.NewFile(uintptr(fd), ""), nil + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } + // Create WireGuard device dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger( mapToWireGuardLogLevel(loggerLevel), "wireguard: ", )) + errs := make(chan error) + + uapi, err := ipc.UAPIListen(interfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := uapi.Accept() + if err != nil { + errs <- err + return + } + go dev.IpcHandle(conn) + } + }() + + logger.Info("UAPI listener started") + endpoint, err := resolveDomain(wgData.Endpoint) if err != nil { logger.Error("Failed to resolve endpoint: %v", err) @@ -377,6 +555,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Error("Failed to bring up WireGuard device: %v", err) } + // configure the interface + err = ConfigureInterface(realInterfaceName, wgData.TunnelIP) + if err != nil { + logger.Error("Failed to configure interface: %v", err) + } + logger.Info("WireGuard device created.") // Ping to bring the tunnel up on the server side quickly // ping(tnet, wgData.ServerIP) @@ -420,6 +604,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh - // Cleanup + uapi.Close() dev.Close() }