From 1b1323b553f8688d677eb2a96ab9bdc2b7e4fba0 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 26 Nov 2025 15:06:16 -0500 Subject: [PATCH] Move network to newt - handle --native mode --- clients.go | 24 +-- clients/clients.go | 139 +++++++++++++--- main.go | 6 - network/interface.go | 165 +++++++++++++++++++ network/interface_notwindows.go | 12 ++ network/interface_windows.go | 63 +++++++ network/network.go | 195 ---------------------- network/route.go | 282 ++++++++++++++++++++++++++++++++ network/route_notwindows.go | 11 ++ network/route_windows.go | 148 +++++++++++++++++ network/settings.go | 190 +++++++++++++++++++++ 11 files changed, 990 insertions(+), 245 deletions(-) create mode 100644 network/interface.go create mode 100644 network/interface_notwindows.go create mode 100644 network/interface_windows.go delete mode 100644 network/network.go create mode 100644 network/route.go create mode 100644 network/route_notwindows.go create mode 100644 network/route_windows.go create mode 100644 network/settings.go diff --git a/clients.go b/clients.go index 0696a24..dd5afba 100644 --- a/clients.go +++ b/clients.go @@ -29,19 +29,9 @@ func setupClients(client *websocket.Client) { host = strings.TrimSuffix(host, "/") - if useNativeInterface { - // setupClientsNative(client, host) - } else { - setupClientsNetstack(client, host) - } - - ready = true -} - -func setupClientsNetstack(client *websocket.Client, host string) { logger.Info("Setting up clients with netstack2...") // Create WireGuard service - wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9") + wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9", useNativeInterface) if err != nil { logger.Fatal("Failed to create WireGuard service: %v", err) } @@ -66,6 +56,8 @@ func setupClientsNetstack(client *websocket.Client, host string) { client.OnTokenUpdate(func(token string) { wgService.SetToken(token) }) + + ready = true } func setDownstreamTNetstack(tnet *netstack.Net) { @@ -77,12 +69,10 @@ func setDownstreamTNetstack(tnet *netstack.Net) { func closeClients() { logger.Info("Closing clients...") if wgService != nil { - wgService.Close(!keepInterface) + wgService.Close() wgService = nil } - // closeWgServiceNative() - if wgTesterServer != nil { wgTesterServer.Stop() wgTesterServer = nil @@ -105,8 +95,6 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) { if wgService != nil { wgService.StartHolepunch(publicKey, endpoint) } - - // clientsHandleNewtConnectionNative(publicKey, endpoint) } func clientsOnConnect() { @@ -116,8 +104,6 @@ func clientsOnConnect() { if wgService != nil { wgService.LoadRemoteConfig() } - - // clientsOnConnectNative() } func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { @@ -129,6 +115,4 @@ func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { if wgService != nil { pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) } - - // clientsAddProxyTargetNative(pm, tunnelIp) } diff --git a/clients/clients.go b/clients/clients.go index a029b83..2f4289c 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -9,6 +9,7 @@ import ( "net" "net/netip" "os" + "runtime" "strconv" "strings" "sync" @@ -18,9 +19,11 @@ import ( "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/netstack2" + "github.com/fosrl/newt/network" "github.com/fosrl/newt/util" "github.com/fosrl/newt/websocket" "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -92,11 +95,12 @@ type WireGuardService struct { // Proxy manager for tunnel TunnelIP string // Shared bind and holepunch manager - sharedBind *bind.SharedBind - holePunchManager *holepunch.Manager + sharedBind *bind.SharedBind + holePunchManager *holepunch.Manager + useNativeInterface bool } -func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) { +func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { var key wgtypes.Key var err error @@ -159,17 +163,18 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)} service := &WireGuardService{ - interfaceName: interfaceName, - mtu: mtu, - client: wsClient, - key: key, - keyFilePath: generateAndSaveKeyTo, - newtId: newtId, - host: host, - lastReadings: make(map[string]PeerReading), - Port: port, - dns: dnsAddrs, - sharedBind: sharedBind, + interfaceName: interfaceName, + mtu: mtu, + client: wsClient, + key: key, + keyFilePath: generateAndSaveKeyTo, + newtId: newtId, + host: host, + lastReadings: make(map[string]PeerReading), + Port: port, + dns: dnsAddrs, + sharedBind: sharedBind, + useNativeInterface: useNativeInterface, } // Create the holepunch manager with ResolveDomain function @@ -200,7 +205,7 @@ func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) { s.othertnet = tnet } -func (s *WireGuardService) Close(rm bool) { +func (s *WireGuardService) Close() { if s.stopGetConfig != nil { s.stopGetConfig() s.stopGetConfig = nil @@ -356,11 +361,94 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.holePunchManager.Stop() } - // Parse the IP address from the config - // tunnelIP := netip.MustParseAddr(wgconfig.IpAddress) + var err error + + if s.useNativeInterface { + // Create native TUN device + var interfaceName = s.interfaceName + if runtime.GOOS == "darwin" { + interfaceName, err = network.FindUnusedUTUN() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to find unused utun: %v", err) + } + } + + s.tun, err = tun.CreateTUN(interfaceName, s.mtu) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to create native TUN device: %v", err) + } + + // Get the real interface name (may differ on some platforms) + if realName, err := s.tun.Name(); err == nil { + interfaceName = realName + } + + s.TunnelIP = tunnelIP.String() + // s.tnet is nil for native interface - proxy features not available + s.tnet = nil + + // Create WireGuard device using the shared bind + s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger( + device.LogLevelSilent, + "wireguard: ", + )) + + fileUAPI, err := func() (*os.File, error) { + return ipc.UAPIOpen(interfaceName) + }() + if err != nil { + logger.Error("UAPI listen error: %v", err) + } + + uapiListener, err := ipc.UAPIListen(interfaceName, fileUAPI) + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } + + go func() { + for { + conn, err := uapiListener.Accept() + if err != nil { + + return + } + go s.device.IpcHandle(conn) + } + }() + logger.Info("UAPI listener started") + + // Configure WireGuard with private key + config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String())) + + err = s.device.IpcSet(config) + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to configure WireGuard device: %v", err) + } + + // Bring up the device + err = s.device.Up() + if err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to bring up WireGuard device: %v", err) + } + + // Configure the network interface with IP address + if err := network.ConfigureInterface(interfaceName, wgconfig.IpAddress, s.mtu); err != nil { + s.mu.Unlock() + return fmt.Errorf("failed to configure interface: %v", err) + } + + logger.Info("WireGuard native device created and configured on %s", interfaceName) + + s.mu.Unlock() + return nil + } // Create TUN device and network stack using netstack - var err error s.tun, s.tnet, err = netstack2.CreateNetTUNWithOptions( []netip.Addr{tunnelIP}, s.dns, @@ -383,8 +471,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { "wireguard: ", )) - // logger.Info("Private key is %s", fixKey(s.key.String())) - // Configure WireGuard with private key config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String())) @@ -459,7 +545,9 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { func (s *WireGuardService) ensureTargets(targets []Target) error { if s.tnet == nil { - return fmt.Errorf("netstack not initialized") + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping target configuration - using native interface (no proxy support)") + return nil } for _, target := range targets { @@ -849,7 +937,8 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { } if s.tnet == nil { - logger.Info("Netstack not initialized") + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping add target - using native interface (no proxy support)") return } @@ -908,7 +997,8 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) { } if s.tnet == nil { - logger.Info("Netstack not initialized") + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping remove target - using native interface (no proxy support)") return } @@ -955,7 +1045,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { } if s.tnet == nil { - logger.Info("Netstack not initialized") + // Native interface mode - proxy features not available, skip silently + logger.Debug("Skipping update target - using native interface (no proxy support)") return } diff --git a/main.go b/main.go index 329fda7..2f7f9b3 100644 --- a/main.go +++ b/main.go @@ -117,7 +117,6 @@ var ( logLevel string interfaceName string generateAndSaveKeyTo string - keepInterface bool acceptClients bool updownScript string dockerSocket string @@ -178,8 +177,6 @@ func main() { regionEnv := os.Getenv("NEWT_REGION") asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES") - keepInterfaceEnv := os.Getenv("KEEP_INTERFACE") - keepInterface = keepInterfaceEnv == "true" acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS") acceptClients = acceptClientsEnv == "true" useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE") @@ -243,9 +240,6 @@ func main() { if generateAndSaveKeyTo == "" { flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key") } - if keepInterfaceEnv == "" { - flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface") - } if useNativeInterfaceEnv == "" { flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux") } diff --git a/network/interface.go b/network/interface.go new file mode 100644 index 0000000..e110ec1 --- /dev/null +++ b/network/interface.go @@ -0,0 +1,165 @@ +package network + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "runtime" + "strconv" + "time" + + "github.com/fosrl/newt/logger" + "github.com/vishvananda/netlink" +) + +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error { + logger.Info("The tunnel IP is: %s", tunnelIp) + + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(tunnelIp) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ip.String() + + logger.Debug("The destination address is: %s", destinationAddress) + + // network.SetTunnelRemoteAddress() // what does this do? + SetIPv4Settings([]string{destinationAddress}, []string{mask}) + SetMTU(mtu) + + if interfaceName == "" { + return nil + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + case "windows": + return configureWindows(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +// waitForInterfaceUp polls the network interface until it's up or times out +func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { + logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) + deadline := time.Now().Add(timeout) + pollInterval := 500 * time.Millisecond + + for time.Now().Before(deadline) { + // Check if interface exists and is up + iface, err := net.InterfaceByName(interfaceName) + if err == nil { + // Check if interface is up + if iface.Flags&net.FlagUp != 0 { + // Check if it has the expected IP + addrs, err := iface.Addrs() + if err == nil { + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if ok && ipNet.IP.Equal(expectedIP) { + logger.Info("Interface %s is up with correct IP", interfaceName) + return nil // Interface is up with correct IP + } + } + logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) + } + } else { + logger.Info("Interface %s exists but is not up yet", interfaceName) + } + } else { + logger.Info("Interface %s not found yet: %v", interfaceName, err) + } + + // Wait before next check + time.Sleep(pollInterval) + } + + return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) +} + +func FindUnusedUTUN() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to list interfaces: %v", err) + } + used := make(map[int]bool) + re := regexp.MustCompile(`^utun(\d+)$`) + for _, iface := range ifaces { + if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { + if num, err := strconv.Atoi(matches[1]); err == nil { + used[num] = true + } + } + } + // Try utun0 up to utun255. + for i := 0; i < 256; i++ { + if !used[i] { + return fmt.Sprintf("utun%d", i), nil + } + } + return "", fmt.Errorf("no unused utun interface found") +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring darwin interface: %s", interfaceName) + + prefix, _ := ipNet.Mask.Size() + ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) + } + + // Bring up the interface + cmd = exec.Command("ifconfig", interfaceName, "up") + logger.Info("Running command: %v", cmd) + + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) + } + + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} diff --git a/network/interface_notwindows.go b/network/interface_notwindows.go new file mode 100644 index 0000000..5d15ace --- /dev/null +++ b/network/interface_notwindows.go @@ -0,0 +1,12 @@ +//go:build !windows + +package network + +import ( + "fmt" + "net" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + return fmt.Errorf("configureWindows called on non-Windows platform") +} diff --git a/network/interface_windows.go b/network/interface_windows.go new file mode 100644 index 0000000..966486b --- /dev/null +++ b/network/interface_windows.go @@ -0,0 +1,63 @@ +//go:build windows + +package network + +import ( + "fmt" + "net" + "net/netip" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Get the LUID for the interface + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + + // Create the IP address prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ip.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ip) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert IP address") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Add the IP address to the interface + logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName) + err = luid.AddIPAddress(prefix) + if err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // This was required when we were using the subprocess "netsh" command to bring up the interface. + // With the winipcfg library, the interface should already be up after adding the IP so we dont + // need this step anymore as far as I can tell. + + // // Wait for the interface to be up and have the correct IP + // err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + // if err != nil { + // return fmt.Errorf("interface did not come up within timeout: %v", err) + // } + + return nil +} diff --git a/network/network.go b/network/network.go deleted file mode 100644 index e359219..0000000 --- a/network/network.go +++ /dev/null @@ -1,195 +0,0 @@ -package network - -import ( - "encoding/binary" - "encoding/json" - "fmt" - "log" - "net" - "time" - - "github.com/google/gopacket" - "github.com/google/gopacket/layers" - "github.com/vishvananda/netlink" - "golang.org/x/net/bpf" - "golang.org/x/net/ipv4" -) - -const ( - udpProtocol = 17 - // EmptyUDPSize is the size of an empty UDP packet - EmptyUDPSize = 28 - timeout = time.Second * 10 -) - -// Server stores data relating to the server -type Server struct { - Hostname string - Addr *net.IPAddr - Port uint16 -} - -// PeerNet stores data about a peer's endpoint -type PeerNet struct { - Resolved bool - IP net.IP - Port uint16 - NewtID string -} - -// GetClientIP gets source ip address that will be used when sending data to dstIP -func GetClientIP(dstIP net.IP) net.IP { - routes, err := netlink.RouteGet(dstIP) - if err != nil { - log.Fatalln("Error getting route:", err) - } - return routes[0].Src -} - -// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr -func HostToAddr(hostStr string) *net.IPAddr { - remoteAddrs, err := net.LookupHost(hostStr) - if err != nil { - log.Fatalln("Error parsing remote address:", err) - } - - for _, addrStr := range remoteAddrs { - if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil { - return remoteAddr - } - } - return nil -} - -// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering -func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn { - packetConn, err := net.ListenPacket("ip4:udp", client.IP.String()) - if err != nil { - log.Fatalln("Error creating packetConn:", err) - } - - rawConn, err := ipv4.NewRawConn(packetConn) - if err != nil { - log.Fatalln("Error creating rawConn:", err) - } - - ApplyBPF(rawConn, server, client) - - return rawConn -} - -// ApplyBPF constructs a BPF program and applies it to the RawConn -func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) { - const ipv4HeaderLen = 20 - const srcIPOffset = 12 - const srcPortOffset = ipv4HeaderLen + 0 - const dstPortOffset = ipv4HeaderLen + 2 - - ipArr := []byte(server.Addr.IP.To4()) - ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3]) - - bpfRaw, err := bpf.Assemble([]bpf.Instruction{ - bpf.LoadAbsolute{Off: srcIPOffset, Size: 4}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: srcPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0}, - - bpf.LoadAbsolute{Off: dstPortOffset, Size: 2}, - bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0}, - - bpf.RetConstant{Val: 1<<(8*4) - 1}, - bpf.RetConstant{Val: 0}, - }) - - if err != nil { - log.Fatalln("Error assembling BPF:", err) - } - - err = rawConn.SetBPF(bpfRaw) - if err != nil { - log.Fatalln("Error setting BPF:", err) - } -} - -// MakePacket constructs a request packet to send to the server -func MakePacket(payload []byte, server *Server, client *PeerNet) []byte { - buf := gopacket.NewSerializeBuffer() - - opts := gopacket.SerializeOptions{ - FixLengths: true, - ComputeChecksums: true, - } - - ipHeader := layers.IPv4{ - SrcIP: client.IP, - DstIP: server.Addr.IP, - Version: 4, - TTL: 64, - Protocol: layers.IPProtocolUDP, - } - - udpHeader := layers.UDP{ - SrcPort: layers.UDPPort(client.Port), - DstPort: layers.UDPPort(server.Port), - } - - payloadLayer := gopacket.Payload(payload) - - udpHeader.SetNetworkLayerForChecksum(&ipHeader) - - gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer) - - return buf.Bytes() -} - -// SendPacket sends packet to the Server -func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - fullPacket := MakePacket(packet, server, client) - _, err := conn.WriteToIP(fullPacket, server.Addr) - return err -} - -// SendDataPacket sends a JSON payload to the Server -func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error { - jsonData, err := json.Marshal(data) - if err != nil { - return fmt.Errorf("failed to marshal payload: %v", err) - } - - return SendPacket(jsonData, conn, server, client) -} - -// RecvPacket receives a UDP packet from server -func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) { - err := conn.SetReadDeadline(time.Now().Add(timeout)) - if err != nil { - return nil, 0, err - } - - response := make([]byte, 4096) - n, err := conn.Read(response) - if err != nil { - return nil, n, err - } - return response, n, nil -} - -// RecvDataPacket receives and unmarshals a JSON packet from server -func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) { - response, n, err := RecvPacket(conn, server, client) - if err != nil { - return nil, err - } - - // Extract payload from UDP packet - payload := response[EmptyUDPSize:n] - return payload, nil -} - -// ParseResponse takes a response packet and parses it into an IP and port -func ParseResponse(response []byte) (net.IP, uint16) { - ip := net.IP(response[:4]) - port := binary.BigEndian.Uint16(response[4:6]) - return ip, port -} diff --git a/network/route.go b/network/route.go new file mode 100644 index 0000000..eb850ee --- /dev/null +++ b/network/route.go @@ -0,0 +1,282 @@ +package network + +import ( + "fmt" + "net" + "os/exec" + "runtime" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/vishvananda/netlink" +) + +func DarwinAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "darwin" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func DarwinRemoveRoute(destination string) error { + if runtime.GOOS != "darwin" { + return nil + } + + cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "linux" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route + route := &netlink.Route{ + Dst: ipNet, + } + + if gateway != "" { + // Route with specific gateway + gw := net.ParseIP(gateway) + if gw == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + route.Gw = gw + logger.Info("Adding route to %s via gateway %s", destination, gateway) + } else if interfaceName != "" { + // Route via interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + route.LinkIndex = link.Attrs().Index + logger.Info("Adding route to %s via interface %s", destination, interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + // Add the route + if err := netlink.RouteAdd(route); err != nil { + return fmt.Errorf("failed to add route: %v", err) + } + + return nil +} + +func LinuxRemoveRoute(destination string) error { + if runtime.GOOS != "linux" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Create route to delete + route := &netlink.Route{ + Dst: ipNet, + } + + logger.Info("Removing route to %s", destination) + + // Delete the route + if err := netlink.RouteDel(route); err != nil { + return fmt.Errorf("failed to delete route: %v", err) + } + + return nil +} + +// addRouteForServerIP adds an OS-specific route for the server IP +func AddRouteForServerIP(serverIP, interfaceName string) error { + if err := AddRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinAddRoute(serverIP, "", interfaceName) + } + // else if runtime.GOOS == "windows" { + // return WindowsAddRoute(serverIP, "", interfaceName) + // } else if runtime.GOOS == "linux" { + // return LinuxAddRoute(serverIP, "", interfaceName) + // } + return nil +} + +// removeRouteForServerIP removes an OS-specific route for the server IP +func RemoveRouteForServerIP(serverIP string, interfaceName string) error { + if err := RemoveRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinRemoveRoute(serverIP) + } + // else if runtime.GOOS == "windows" { + // return WindowsRemoveRoute(serverIP) + // } else if runtime.GOOS == "linux" { + // return LinuxRemoveRoute(serverIP) + // } + return nil +} + +func AddRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +func RemoveRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +// addRoutes adds routes for each subnet in RemoteSubnets +func AddRoutes(remoteSubnets []string, interfaceName string) error { + if len(remoteSubnets) == 0 { + return nil + } + + // Add routes for each subnet + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := AddRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to add network config for subnet %s: %v", subnet, err) + continue + } + + // Add route based on operating system + if interfaceName == "" { + continue + } + + if runtime.GOOS == "darwin" { + if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Added route for remote subnet: %s", subnet) + } + return nil +} + +// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets +func RemoveRoutes(remoteSubnets []string) error { + if len(remoteSubnets) == 0 { + return nil + } + + // Remove routes for each subnet + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := RemoveRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to remove network config for subnet %s: %v", subnet, err) + continue + } + + // Remove route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Removed route for remote subnet: %s", subnet) + } + + return nil +} diff --git a/network/route_notwindows.go b/network/route_notwindows.go new file mode 100644 index 0000000..6984c71 --- /dev/null +++ b/network/route_notwindows.go @@ -0,0 +1,11 @@ +//go:build !windows + +package network + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + return nil +} + +func WindowsRemoveRoute(destination string) error { + return nil +} diff --git a/network/route_windows.go b/network/route_windows.go new file mode 100644 index 0000000..ba613b6 --- /dev/null +++ b/network/route_windows.go @@ -0,0 +1,148 @@ +//go:build windows + +package network + +import ( + "fmt" + "net" + "net/netip" + "runtime" + + "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/windows/tunnel/winipcfg" +) + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "windows" { + return nil + } + + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + var luid winipcfg.LUID + var nextHop netip.Addr + + if interfaceName != "" { + // Get the interface LUID - needed for both gateway and interface-only routes + iface, err := net.InterfaceByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index)) + if err != nil { + return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err) + } + } + + if gateway != "" { + // Route with specific gateway + gwIP := net.ParseIP(gateway) + if gwIP == nil { + return fmt.Errorf("invalid gateway address: %s", gateway) + } + // Convert to correct IP version + if ip4 := gwIP.To4(); ip4 != nil { + nextHop, _ = netip.AddrFromSlice(ip4) + } else { + nextHop, _ = netip.AddrFromSlice(gwIP) + } + if !nextHop.IsValid() { + return fmt.Errorf("failed to convert gateway IP") + } + logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName) + } else if interfaceName != "" { + // Route via interface only + if addr.Is4() { + nextHop = netip.IPv4Unspecified() + } else { + nextHop = netip.IPv6Unspecified() + } + logger.Info("Adding route to %s via interface %s", destination, interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + // Add the route using winipcfg + err = luid.AddRoute(prefix, nextHop, 1) + if err != nil { + return fmt.Errorf("failed to add route: %v", err) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination CIDR + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Convert to netip.Prefix + maskBits, _ := ipNet.Mask.Size() + + // Ensure we convert to the correct IP version (IPv4 vs IPv6) + var addr netip.Addr + if ip4 := ipNet.IP.To4(); ip4 != nil { + // IPv4 address + addr, _ = netip.AddrFromSlice(ip4) + } else { + // IPv6 address + addr, _ = netip.AddrFromSlice(ipNet.IP) + } + if !addr.IsValid() { + return fmt.Errorf("failed to convert destination IP") + } + prefix := netip.PrefixFrom(addr, maskBits) + + // Get all routes and find the one to delete + // We need to get the LUID from the existing route + var family winipcfg.AddressFamily + if addr.Is4() { + family = 2 // AF_INET + } else { + family = 23 // AF_INET6 + } + + routes, err := winipcfg.GetIPForwardTable2(family) + if err != nil { + return fmt.Errorf("failed to get route table: %v", err) + } + + // Find and delete matching route + for _, route := range routes { + routePrefix := route.DestinationPrefix.Prefix() + if routePrefix == prefix { + logger.Info("Removing route to %s", destination) + err = route.Delete() + if err != nil { + return fmt.Errorf("failed to delete route: %v", err) + } + return nil + } + } + + return fmt.Errorf("route to %s not found", destination) +} diff --git a/network/settings.go b/network/settings.go new file mode 100644 index 0000000..e7792e0 --- /dev/null +++ b/network/settings.go @@ -0,0 +1,190 @@ +package network + +import ( + "encoding/json" + "sync" + + "github.com/fosrl/newt/logger" +) + +// NetworkSettings represents the network configuration for the tunnel +type NetworkSettings struct { + TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"` + MTU *int `json:"mtu,omitempty"` + DNSServers []string `json:"dns_servers,omitempty"` + IPv4Addresses []string `json:"ipv4_addresses,omitempty"` + IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"` + IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"` + IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"` + IPv6Addresses []string `json:"ipv6_addresses,omitempty"` + IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"` + IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"` + IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"` +} + +// IPv4Route represents an IPv4 route +type IPv4Route struct { + DestinationAddress string `json:"destination_address"` + SubnetMask string `json:"subnet_mask,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +// IPv6Route represents an IPv6 route +type IPv6Route struct { + DestinationAddress string `json:"destination_address"` + NetworkPrefixLength int `json:"network_prefix_length,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +var ( + networkSettings NetworkSettings + networkSettingsMutex sync.RWMutex + incrementor int +) + +// SetTunnelRemoteAddress sets the tunnel remote address +func SetTunnelRemoteAddress(address string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.TunnelRemoteAddress = address + incrementor++ + logger.Info("Set tunnel remote address: %s", address) +} + +// SetMTU sets the MTU value +func SetMTU(mtu int) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.MTU = &mtu + incrementor++ + logger.Info("Set MTU: %d", mtu) +} + +// SetDNSServers sets the DNS servers +func SetDNSServers(servers []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.DNSServers = servers + incrementor++ + logger.Info("Set DNS servers: %v", servers) +} + +// SetIPv4Settings sets IPv4 addresses and subnet masks +func SetIPv4Settings(addresses []string, subnetMasks []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4Addresses = addresses + networkSettings.IPv4SubnetMasks = subnetMasks + incrementor++ + logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks) +} + +// SetIPv4IncludedRoutes sets the included IPv4 routes +func SetIPv4IncludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4IncludedRoutes = routes + incrementor++ + logger.Info("Set IPv4 included routes: %d routes", len(routes)) +} + +func AddIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + + // make sure it does not already exist + for _, r := range networkSettings.IPv4IncludedRoutes { + if r == route { + logger.Info("IPv4 included route already exists: %+v", route) + return + } + } + + networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route) + incrementor++ + logger.Info("Added IPv4 included route: %+v", route) +} + +func RemoveIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + routes := networkSettings.IPv4IncludedRoutes + for i, r := range routes { + if r == route { + networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...) + logger.Info("Removed IPv4 included route: %+v", route) + return + } + } + incrementor++ + logger.Info("IPv4 included route not found for removal: %+v", route) +} + +func SetIPv4ExcludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4ExcludedRoutes = routes + incrementor++ + logger.Info("Set IPv4 excluded routes: %d routes", len(routes)) +} + +// SetIPv6Settings sets IPv6 addresses and network prefixes +func SetIPv6Settings(addresses []string, networkPrefixes []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6Addresses = addresses + networkSettings.IPv6NetworkPrefixes = networkPrefixes + incrementor++ + logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes) +} + +// SetIPv6IncludedRoutes sets the included IPv6 routes +func SetIPv6IncludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6IncludedRoutes = routes + incrementor++ + logger.Info("Set IPv6 included routes: %d routes", len(routes)) +} + +// SetIPv6ExcludedRoutes sets the excluded IPv6 routes +func SetIPv6ExcludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6ExcludedRoutes = routes + incrementor++ + logger.Info("Set IPv6 excluded routes: %d routes", len(routes)) +} + +// ClearNetworkSettings clears all network settings +func ClearNetworkSettings() { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings = NetworkSettings{} + incrementor++ + logger.Info("Cleared all network settings") +} + +func GetJSON() (string, error) { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + data, err := json.MarshalIndent(networkSettings, "", " ") + if err != nil { + return "", err + } + return string(data), nil +} + +func GetSettings() NetworkSettings { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + return networkSettings +} + +func GetIncrementor() int { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + return incrementor +}