This commit is contained in:
Owen
2025-02-21 18:50:58 -05:00
parent 9d9f10a799
commit 7424caca8a
3 changed files with 196 additions and 2 deletions

186
main.go
View File

@@ -9,16 +9,20 @@ import (
"net"
"os"
"os/signal"
"runtime"
"strconv"
"strings"
"syscall"
"unsafe"
"github.com/fosrl/newt/logger"
"github.com/fosrl/olm/websocket"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -191,6 +195,128 @@ func resolveDomain(domain string) (string, error) {
return ipAddr, nil
}
// ConfigureInterface configures a network interface with an IP address and brings it up
func ConfigureInterface(interfaceName string, ipAddr string) error {
// Parse the IP address and network
ip, ipNet, err := net.ParseCIDR(ipAddr)
if err != nil {
return fmt.Errorf("invalid IP address: %v", err)
}
switch runtime.GOOS {
case "linux":
return configureLinux(interfaceName, ip, ipNet)
case "darwin":
return configureDarwin(interfaceName, ip, ipNet)
default:
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
}
}
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
// Get interface by name
iface, err := net.InterfaceByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
// print something using the iface
logger.Info("Interface %s: %v", interfaceName, iface)
// Create socket
fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_DGRAM, 0)
if err != nil {
return fmt.Errorf("failed to create socket: %v", err)
}
defer syscall.Close(fd)
// Prepare interface request structure
ifr := struct {
Name [16]byte
Flags uint16
}{}
copy(ifr.Name[:], interfaceName)
// Get current flags
if err := ioctl(fd, syscall.SIOCGIFFLAGS, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to get interface flags: %v", err)
}
// Set interface up
ifr.Flags |= syscall.IFF_UP | syscall.IFF_RUNNING
if err := ioctl(fd, syscall.SIOCSIFFLAGS, uintptr(unsafe.Pointer(&ifr))); err != nil {
return fmt.Errorf("failed to set interface up: %v", err)
}
// Prepare address structure
var addr syscall.SockaddrInet4
copy(addr.Addr[:], ip.To4())
// Create interface address request
ifra := struct {
Name [16]byte
Addr syscall.RawSockaddrInet4
Mask syscall.RawSockaddrInet4
}{}
copy(ifra.Name[:], interfaceName)
copy(ifra.Addr.Addr[:], ip.To4())
copy(ifra.Mask.Addr[:], ipNet.Mask)
// Set IP address
if err := ioctl(fd, syscall.SIOCSIFADDR, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set interface address: %v", err)
}
// Set netmask
if err := ioctl(fd, syscall.SIOCSIFNETMASK, uintptr(unsafe.Pointer(&ifra))); err != nil {
return fmt.Errorf("failed to set interface netmask: %v", err)
}
return nil
}
// Helper function for ioctl calls
func ioctl(fd int, request uint, argp uintptr) error {
_, _, errno := syscall.Syscall(
syscall.SYS_IOCTL,
uintptr(fd),
uintptr(request),
argp,
)
if errno != 0 {
return os.NewSyscallError("ioctl", errno)
}
return nil
}
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
// Get the interface
link, err := netlink.LinkByName(interfaceName)
if err != nil {
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
}
// Create the IP address attributes
addr := &netlink.Addr{
IPNet: &net.IPNet{
IP: ip,
Mask: ipNet.Mask,
},
}
// Add the IP address to the interface
if err := netlink.AddrAdd(link, addr); err != nil {
return fmt.Errorf("failed to add IP address: %v", err)
}
// Bring up the interface
if err := netlink.LinkSetUp(link); err != nil {
return fmt.Errorf("failed to bring up interface: %v", err)
}
return nil
}
func main() {
var (
endpoint string
@@ -285,6 +411,7 @@ func main() {
var dev *device.Device
// var connected bool
var wgData WgData
var uapi *os.File
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
@@ -309,6 +436,8 @@ func main() {
// return
// }
logger.Info("Received message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
@@ -347,12 +476,61 @@ func main() {
return
}
realInterfaceName, err2 := tdev.Name()
if err2 == nil {
interfaceName = realInterfaceName
}
// open UAPI file (or use supplied fd)
fileUAPI, err := func() (*os.File, error) {
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
if uapiFdStr == "" {
return ipc.UAPIOpen(interfaceName)
}
// use supplied fd
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil
}()
if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
// Create WireGuard device
dev = device.NewDevice(tdev, conn.NewDefaultBind(), device.NewLogger(
mapToWireGuardLogLevel(loggerLevel),
"wireguard: ",
))
errs := make(chan error)
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() {
for {
conn, err := uapi.Accept()
if err != nil {
errs <- err
return
}
go dev.IpcHandle(conn)
}
}()
logger.Info("UAPI listener started")
endpoint, err := resolveDomain(wgData.Endpoint)
if err != nil {
logger.Error("Failed to resolve endpoint: %v", err)
@@ -377,6 +555,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Error("Failed to bring up WireGuard device: %v", err)
}
// configure the interface
err = ConfigureInterface(realInterfaceName, wgData.TunnelIP)
if err != nil {
logger.Error("Failed to configure interface: %v", err)
}
logger.Info("WireGuard device created.")
// Ping to bring the tunnel up on the server side quickly
// ping(tnet, wgData.ServerIP)
@@ -420,6 +604,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
// Cleanup
uapi.Close()
dev.Close()
}