mirror of
https://github.com/fosrl/newt.git
synced 2026-02-07 21:46:39 +00:00
Move network to newt - handle --native mode
This commit is contained in:
24
clients.go
24
clients.go
@@ -29,19 +29,9 @@ func setupClients(client *websocket.Client) {
|
|||||||
|
|
||||||
host = strings.TrimSuffix(host, "/")
|
host = strings.TrimSuffix(host, "/")
|
||||||
|
|
||||||
if useNativeInterface {
|
|
||||||
// setupClientsNative(client, host)
|
|
||||||
} else {
|
|
||||||
setupClientsNetstack(client, host)
|
|
||||||
}
|
|
||||||
|
|
||||||
ready = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupClientsNetstack(client *websocket.Client, host string) {
|
|
||||||
logger.Info("Setting up clients with netstack2...")
|
logger.Info("Setting up clients with netstack2...")
|
||||||
// Create WireGuard service
|
// Create WireGuard service
|
||||||
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9")
|
wgService, err = wgnetstack.NewWireGuardService(interfaceName, mtuInt, generateAndSaveKeyTo, host, id, client, "9.9.9.9", useNativeInterface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create WireGuard service: %v", err)
|
logger.Fatal("Failed to create WireGuard service: %v", err)
|
||||||
}
|
}
|
||||||
@@ -66,6 +56,8 @@ func setupClientsNetstack(client *websocket.Client, host string) {
|
|||||||
client.OnTokenUpdate(func(token string) {
|
client.OnTokenUpdate(func(token string) {
|
||||||
wgService.SetToken(token)
|
wgService.SetToken(token)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
ready = true
|
||||||
}
|
}
|
||||||
|
|
||||||
func setDownstreamTNetstack(tnet *netstack.Net) {
|
func setDownstreamTNetstack(tnet *netstack.Net) {
|
||||||
@@ -77,12 +69,10 @@ func setDownstreamTNetstack(tnet *netstack.Net) {
|
|||||||
func closeClients() {
|
func closeClients() {
|
||||||
logger.Info("Closing clients...")
|
logger.Info("Closing clients...")
|
||||||
if wgService != nil {
|
if wgService != nil {
|
||||||
wgService.Close(!keepInterface)
|
wgService.Close()
|
||||||
wgService = nil
|
wgService = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeWgServiceNative()
|
|
||||||
|
|
||||||
if wgTesterServer != nil {
|
if wgTesterServer != nil {
|
||||||
wgTesterServer.Stop()
|
wgTesterServer.Stop()
|
||||||
wgTesterServer = nil
|
wgTesterServer = nil
|
||||||
@@ -105,8 +95,6 @@ func clientsHandleNewtConnection(publicKey string, endpoint string) {
|
|||||||
if wgService != nil {
|
if wgService != nil {
|
||||||
wgService.StartHolepunch(publicKey, endpoint)
|
wgService.StartHolepunch(publicKey, endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientsHandleNewtConnectionNative(publicKey, endpoint)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientsOnConnect() {
|
func clientsOnConnect() {
|
||||||
@@ -116,8 +104,6 @@ func clientsOnConnect() {
|
|||||||
if wgService != nil {
|
if wgService != nil {
|
||||||
wgService.LoadRemoteConfig()
|
wgService.LoadRemoteConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientsOnConnectNative()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
||||||
@@ -129,6 +115,4 @@ func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) {
|
|||||||
if wgService != nil {
|
if wgService != nil {
|
||||||
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
// clientsAddProxyTargetNative(pm, tunnelIp)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -18,9 +19,11 @@ import (
|
|||||||
"github.com/fosrl/newt/holepunch"
|
"github.com/fosrl/newt/holepunch"
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/netstack2"
|
"github.com/fosrl/newt/netstack2"
|
||||||
|
"github.com/fosrl/newt/network"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/newt/websocket"
|
"github.com/fosrl/newt/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -92,11 +95,12 @@ type WireGuardService struct {
|
|||||||
// Proxy manager for tunnel
|
// Proxy manager for tunnel
|
||||||
TunnelIP string
|
TunnelIP string
|
||||||
// Shared bind and holepunch manager
|
// Shared bind and holepunch manager
|
||||||
sharedBind *bind.SharedBind
|
sharedBind *bind.SharedBind
|
||||||
holePunchManager *holepunch.Manager
|
holePunchManager *holepunch.Manager
|
||||||
|
useNativeInterface bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string) (*WireGuardService, error) {
|
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
|
||||||
var key wgtypes.Key
|
var key wgtypes.Key
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
@@ -159,17 +163,18 @@ func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo str
|
|||||||
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
|
dnsAddrs := []netip.Addr{netip.MustParseAddr(dns)}
|
||||||
|
|
||||||
service := &WireGuardService{
|
service := &WireGuardService{
|
||||||
interfaceName: interfaceName,
|
interfaceName: interfaceName,
|
||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
client: wsClient,
|
client: wsClient,
|
||||||
key: key,
|
key: key,
|
||||||
keyFilePath: generateAndSaveKeyTo,
|
keyFilePath: generateAndSaveKeyTo,
|
||||||
newtId: newtId,
|
newtId: newtId,
|
||||||
host: host,
|
host: host,
|
||||||
lastReadings: make(map[string]PeerReading),
|
lastReadings: make(map[string]PeerReading),
|
||||||
Port: port,
|
Port: port,
|
||||||
dns: dnsAddrs,
|
dns: dnsAddrs,
|
||||||
sharedBind: sharedBind,
|
sharedBind: sharedBind,
|
||||||
|
useNativeInterface: useNativeInterface,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create the holepunch manager with ResolveDomain function
|
// Create the holepunch manager with ResolveDomain function
|
||||||
@@ -200,7 +205,7 @@ func (s *WireGuardService) SetOthertnet(tnet *netstack.Net) {
|
|||||||
s.othertnet = tnet
|
s.othertnet = tnet
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) Close(rm bool) {
|
func (s *WireGuardService) Close() {
|
||||||
if s.stopGetConfig != nil {
|
if s.stopGetConfig != nil {
|
||||||
s.stopGetConfig()
|
s.stopGetConfig()
|
||||||
s.stopGetConfig = nil
|
s.stopGetConfig = nil
|
||||||
@@ -356,11 +361,94 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
s.holePunchManager.Stop()
|
s.holePunchManager.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the IP address from the config
|
var err error
|
||||||
// tunnelIP := netip.MustParseAddr(wgconfig.IpAddress)
|
|
||||||
|
if s.useNativeInterface {
|
||||||
|
// Create native TUN device
|
||||||
|
var interfaceName = s.interfaceName
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
interfaceName, err = network.FindUnusedUTUN()
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("failed to find unused utun: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.tun, err = tun.CreateTUN(interfaceName, s.mtu)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("failed to create native TUN device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the real interface name (may differ on some platforms)
|
||||||
|
if realName, err := s.tun.Name(); err == nil {
|
||||||
|
interfaceName = realName
|
||||||
|
}
|
||||||
|
|
||||||
|
s.TunnelIP = tunnelIP.String()
|
||||||
|
// s.tnet is nil for native interface - proxy features not available
|
||||||
|
s.tnet = nil
|
||||||
|
|
||||||
|
// Create WireGuard device using the shared bind
|
||||||
|
s.device = device.NewDevice(s.tun, s.sharedBind, device.NewLogger(
|
||||||
|
device.LogLevelSilent,
|
||||||
|
"wireguard: ",
|
||||||
|
))
|
||||||
|
|
||||||
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
|
return ipc.UAPIOpen(interfaceName)
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("UAPI listen error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
uapiListener, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
conn, err := uapiListener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go s.device.IpcHandle(conn)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
|
// Configure WireGuard with private key
|
||||||
|
config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String()))
|
||||||
|
|
||||||
|
err = s.device.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("failed to configure WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the device
|
||||||
|
err = s.device.Up()
|
||||||
|
if err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("failed to bring up WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure the network interface with IP address
|
||||||
|
if err := network.ConfigureInterface(interfaceName, wgconfig.IpAddress, s.mtu); err != nil {
|
||||||
|
s.mu.Unlock()
|
||||||
|
return fmt.Errorf("failed to configure interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("WireGuard native device created and configured on %s", interfaceName)
|
||||||
|
|
||||||
|
s.mu.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Create TUN device and network stack using netstack
|
// Create TUN device and network stack using netstack
|
||||||
var err error
|
|
||||||
s.tun, s.tnet, err = netstack2.CreateNetTUNWithOptions(
|
s.tun, s.tnet, err = netstack2.CreateNetTUNWithOptions(
|
||||||
[]netip.Addr{tunnelIP},
|
[]netip.Addr{tunnelIP},
|
||||||
s.dns,
|
s.dns,
|
||||||
@@ -383,8 +471,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
|
|||||||
"wireguard: ",
|
"wireguard: ",
|
||||||
))
|
))
|
||||||
|
|
||||||
// logger.Info("Private key is %s", fixKey(s.key.String()))
|
|
||||||
|
|
||||||
// Configure WireGuard with private key
|
// Configure WireGuard with private key
|
||||||
config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String()))
|
config := fmt.Sprintf("private_key=%s", util.FixKey(s.key.String()))
|
||||||
|
|
||||||
@@ -459,7 +545,9 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error {
|
|||||||
|
|
||||||
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
func (s *WireGuardService) ensureTargets(targets []Target) error {
|
||||||
if s.tnet == nil {
|
if s.tnet == nil {
|
||||||
return fmt.Errorf("netstack not initialized")
|
// Native interface mode - proxy features not available, skip silently
|
||||||
|
logger.Debug("Skipping target configuration - using native interface (no proxy support)")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
@@ -849,7 +937,8 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.tnet == nil {
|
if s.tnet == nil {
|
||||||
logger.Info("Netstack not initialized")
|
// Native interface mode - proxy features not available, skip silently
|
||||||
|
logger.Debug("Skipping add target - using native interface (no proxy support)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -908,7 +997,8 @@ func (s *WireGuardService) handleRemoveTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.tnet == nil {
|
if s.tnet == nil {
|
||||||
logger.Info("Netstack not initialized")
|
// Native interface mode - proxy features not available, skip silently
|
||||||
|
logger.Debug("Skipping remove target - using native interface (no proxy support)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -955,7 +1045,8 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if s.tnet == nil {
|
if s.tnet == nil {
|
||||||
logger.Info("Netstack not initialized")
|
// Native interface mode - proxy features not available, skip silently
|
||||||
|
logger.Debug("Skipping update target - using native interface (no proxy support)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
6
main.go
6
main.go
@@ -117,7 +117,6 @@ var (
|
|||||||
logLevel string
|
logLevel string
|
||||||
interfaceName string
|
interfaceName string
|
||||||
generateAndSaveKeyTo string
|
generateAndSaveKeyTo string
|
||||||
keepInterface bool
|
|
||||||
acceptClients bool
|
acceptClients bool
|
||||||
updownScript string
|
updownScript string
|
||||||
dockerSocket string
|
dockerSocket string
|
||||||
@@ -178,8 +177,6 @@ func main() {
|
|||||||
regionEnv := os.Getenv("NEWT_REGION")
|
regionEnv := os.Getenv("NEWT_REGION")
|
||||||
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
asyncBytesEnv := os.Getenv("NEWT_METRICS_ASYNC_BYTES")
|
||||||
|
|
||||||
keepInterfaceEnv := os.Getenv("KEEP_INTERFACE")
|
|
||||||
keepInterface = keepInterfaceEnv == "true"
|
|
||||||
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
|
acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS")
|
||||||
acceptClients = acceptClientsEnv == "true"
|
acceptClients = acceptClientsEnv == "true"
|
||||||
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
|
useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE")
|
||||||
@@ -243,9 +240,6 @@ func main() {
|
|||||||
if generateAndSaveKeyTo == "" {
|
if generateAndSaveKeyTo == "" {
|
||||||
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
flag.StringVar(&generateAndSaveKeyTo, "generateAndSaveKeyTo", "", "Path to save generated private key")
|
||||||
}
|
}
|
||||||
if keepInterfaceEnv == "" {
|
|
||||||
flag.BoolVar(&keepInterface, "keep-interface", false, "Keep the WireGuard interface")
|
|
||||||
}
|
|
||||||
if useNativeInterfaceEnv == "" {
|
if useNativeInterfaceEnv == "" {
|
||||||
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
|
flag.BoolVar(&useNativeInterface, "native", false, "Use native WireGuard interface (requires WireGuard kernel module) and linux")
|
||||||
}
|
}
|
||||||
|
|||||||
165
network/interface.go
Normal file
165
network/interface.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
|
func ConfigureInterface(interfaceName string, tunnelIp string, mtu int) error {
|
||||||
|
logger.Info("The tunnel IP is: %s", tunnelIp)
|
||||||
|
|
||||||
|
// Parse the IP address and network
|
||||||
|
ip, ipNet, err := net.ParseCIDR(tunnelIp)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ip.String()
|
||||||
|
|
||||||
|
logger.Debug("The destination address is: %s", destinationAddress)
|
||||||
|
|
||||||
|
// network.SetTunnelRemoteAddress() // what does this do?
|
||||||
|
SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||||
|
SetMTU(mtu)
|
||||||
|
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return configureLinux(interfaceName, ip, ipNet)
|
||||||
|
case "darwin":
|
||||||
|
return configureDarwin(interfaceName, ip, ipNet)
|
||||||
|
case "windows":
|
||||||
|
return configureWindows(interfaceName, ip, ipNet)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||||
|
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||||
|
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
pollInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
// Check if interface exists and is up
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err == nil {
|
||||||
|
// Check if interface is up
|
||||||
|
if iface.Flags&net.FlagUp != 0 {
|
||||||
|
// Check if it has the expected IP
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err == nil {
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if ok && ipNet.IP.Equal(expectedIP) {
|
||||||
|
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||||
|
return nil // Interface is up with correct IP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait before next check
|
||||||
|
time.Sleep(pollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindUnusedUTUN() (string, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||||
|
}
|
||||||
|
used := make(map[int]bool)
|
||||||
|
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||||
|
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||||
|
used[num] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try utun0 up to utun255.
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
if !used[i] {
|
||||||
|
return fmt.Sprintf("utun%d", i), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no unused utun interface found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||||
|
|
||||||
|
prefix, _ := ipNet.Mask.Size()
|
||||||
|
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||||
|
|
||||||
|
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
// Get the interface
|
||||||
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the IP address attributes
|
||||||
|
addr := &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: ipNet.Mask,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the IP address to the interface
|
||||||
|
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
if err := netlink.LinkSetUp(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
12
network/interface_notwindows.go
Normal file
12
network/interface_notwindows.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
return fmt.Errorf("configureWindows called on non-Windows platform")
|
||||||
|
}
|
||||||
63
network/interface_windows.go
Normal file
63
network/interface_windows.go
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||||
|
|
||||||
|
// Get the LUID for the interface
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
luid, err := winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the IP address prefix
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
|
||||||
|
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||||
|
var addr netip.Addr
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
// IPv4 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ip4)
|
||||||
|
} else {
|
||||||
|
// IPv6 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ip)
|
||||||
|
}
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return fmt.Errorf("failed to convert IP address")
|
||||||
|
}
|
||||||
|
prefix := netip.PrefixFrom(addr, maskBits)
|
||||||
|
|
||||||
|
// Add the IP address to the interface
|
||||||
|
logger.Info("Adding IP address %s to interface %s", prefix.String(), interfaceName)
|
||||||
|
err = luid.AddIPAddress(prefix)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This was required when we were using the subprocess "netsh" command to bring up the interface.
|
||||||
|
// With the winipcfg library, the interface should already be up after adding the IP so we dont
|
||||||
|
// need this step anymore as far as I can tell.
|
||||||
|
|
||||||
|
// // Wait for the interface to be up and have the correct IP
|
||||||
|
// err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||||
|
// if err != nil {
|
||||||
|
// return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||||
|
// }
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,195 +0,0 @@
|
|||||||
package network
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"log"
|
|
||||||
"net"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
|
||||||
"github.com/google/gopacket/layers"
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
"golang.org/x/net/bpf"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
udpProtocol = 17
|
|
||||||
// EmptyUDPSize is the size of an empty UDP packet
|
|
||||||
EmptyUDPSize = 28
|
|
||||||
timeout = time.Second * 10
|
|
||||||
)
|
|
||||||
|
|
||||||
// Server stores data relating to the server
|
|
||||||
type Server struct {
|
|
||||||
Hostname string
|
|
||||||
Addr *net.IPAddr
|
|
||||||
Port uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeerNet stores data about a peer's endpoint
|
|
||||||
type PeerNet struct {
|
|
||||||
Resolved bool
|
|
||||||
IP net.IP
|
|
||||||
Port uint16
|
|
||||||
NewtID string
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetClientIP gets source ip address that will be used when sending data to dstIP
|
|
||||||
func GetClientIP(dstIP net.IP) net.IP {
|
|
||||||
routes, err := netlink.RouteGet(dstIP)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error getting route:", err)
|
|
||||||
}
|
|
||||||
return routes[0].Src
|
|
||||||
}
|
|
||||||
|
|
||||||
// HostToAddr resolves a hostname, whether DNS or IP to a valid net.IPAddr
|
|
||||||
func HostToAddr(hostStr string) *net.IPAddr {
|
|
||||||
remoteAddrs, err := net.LookupHost(hostStr)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error parsing remote address:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, addrStr := range remoteAddrs {
|
|
||||||
if remoteAddr, err := net.ResolveIPAddr("ip4", addrStr); err == nil {
|
|
||||||
return remoteAddr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupRawConn creates an ipv4 and udp only RawConn and applies packet filtering
|
|
||||||
func SetupRawConn(server *Server, client *PeerNet) *ipv4.RawConn {
|
|
||||||
packetConn, err := net.ListenPacket("ip4:udp", client.IP.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error creating packetConn:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
rawConn, err := ipv4.NewRawConn(packetConn)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error creating rawConn:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ApplyBPF(rawConn, server, client)
|
|
||||||
|
|
||||||
return rawConn
|
|
||||||
}
|
|
||||||
|
|
||||||
// ApplyBPF constructs a BPF program and applies it to the RawConn
|
|
||||||
func ApplyBPF(rawConn *ipv4.RawConn, server *Server, client *PeerNet) {
|
|
||||||
const ipv4HeaderLen = 20
|
|
||||||
const srcIPOffset = 12
|
|
||||||
const srcPortOffset = ipv4HeaderLen + 0
|
|
||||||
const dstPortOffset = ipv4HeaderLen + 2
|
|
||||||
|
|
||||||
ipArr := []byte(server.Addr.IP.To4())
|
|
||||||
ipInt := uint32(ipArr[0])<<(3*8) + uint32(ipArr[1])<<(2*8) + uint32(ipArr[2])<<8 + uint32(ipArr[3])
|
|
||||||
|
|
||||||
bpfRaw, err := bpf.Assemble([]bpf.Instruction{
|
|
||||||
bpf.LoadAbsolute{Off: srcIPOffset, Size: 4},
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: ipInt, SkipFalse: 5, SkipTrue: 0},
|
|
||||||
|
|
||||||
bpf.LoadAbsolute{Off: srcPortOffset, Size: 2},
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(server.Port), SkipFalse: 3, SkipTrue: 0},
|
|
||||||
|
|
||||||
bpf.LoadAbsolute{Off: dstPortOffset, Size: 2},
|
|
||||||
bpf.JumpIf{Cond: bpf.JumpEqual, Val: uint32(client.Port), SkipFalse: 1, SkipTrue: 0},
|
|
||||||
|
|
||||||
bpf.RetConstant{Val: 1<<(8*4) - 1},
|
|
||||||
bpf.RetConstant{Val: 0},
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error assembling BPF:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rawConn.SetBPF(bpfRaw)
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalln("Error setting BPF:", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// MakePacket constructs a request packet to send to the server
|
|
||||||
func MakePacket(payload []byte, server *Server, client *PeerNet) []byte {
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
|
||||||
|
|
||||||
opts := gopacket.SerializeOptions{
|
|
||||||
FixLengths: true,
|
|
||||||
ComputeChecksums: true,
|
|
||||||
}
|
|
||||||
|
|
||||||
ipHeader := layers.IPv4{
|
|
||||||
SrcIP: client.IP,
|
|
||||||
DstIP: server.Addr.IP,
|
|
||||||
Version: 4,
|
|
||||||
TTL: 64,
|
|
||||||
Protocol: layers.IPProtocolUDP,
|
|
||||||
}
|
|
||||||
|
|
||||||
udpHeader := layers.UDP{
|
|
||||||
SrcPort: layers.UDPPort(client.Port),
|
|
||||||
DstPort: layers.UDPPort(server.Port),
|
|
||||||
}
|
|
||||||
|
|
||||||
payloadLayer := gopacket.Payload(payload)
|
|
||||||
|
|
||||||
udpHeader.SetNetworkLayerForChecksum(&ipHeader)
|
|
||||||
|
|
||||||
gopacket.SerializeLayers(buf, opts, &ipHeader, &udpHeader, &payloadLayer)
|
|
||||||
|
|
||||||
return buf.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendPacket sends packet to the Server
|
|
||||||
func SendPacket(packet []byte, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
|
||||||
fullPacket := MakePacket(packet, server, client)
|
|
||||||
_, err := conn.WriteToIP(fullPacket, server.Addr)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// SendDataPacket sends a JSON payload to the Server
|
|
||||||
func SendDataPacket(data interface{}, conn *ipv4.RawConn, server *Server, client *PeerNet) error {
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to marshal payload: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SendPacket(jsonData, conn, server, client)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RecvPacket receives a UDP packet from server
|
|
||||||
func RecvPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, int, error) {
|
|
||||||
err := conn.SetReadDeadline(time.Now().Add(timeout))
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
response := make([]byte, 4096)
|
|
||||||
n, err := conn.Read(response)
|
|
||||||
if err != nil {
|
|
||||||
return nil, n, err
|
|
||||||
}
|
|
||||||
return response, n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RecvDataPacket receives and unmarshals a JSON packet from server
|
|
||||||
func RecvDataPacket(conn *ipv4.RawConn, server *Server, client *PeerNet) ([]byte, error) {
|
|
||||||
response, n, err := RecvPacket(conn, server, client)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract payload from UDP packet
|
|
||||||
payload := response[EmptyUDPSize:n]
|
|
||||||
return payload, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParseResponse takes a response packet and parses it into an IP and port
|
|
||||||
func ParseResponse(response []byte) (net.IP, uint16) {
|
|
||||||
ip := net.IP(response[:4])
|
|
||||||
port := binary.BigEndian.Uint16(response[4:6])
|
|
||||||
return ip, port
|
|
||||||
}
|
|
||||||
282
network/route.go
Normal file
282
network/route.go
Normal file
@@ -0,0 +1,282 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse destination CIDR
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create route
|
||||||
|
route := &netlink.Route{
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
gw := net.ParseIP(gateway)
|
||||||
|
if gw == nil {
|
||||||
|
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||||
|
}
|
||||||
|
route.Gw = gw
|
||||||
|
logger.Info("Adding route to %s via gateway %s", destination, gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
route.LinkIndex = link.Attrs().Index
|
||||||
|
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the route
|
||||||
|
if err := netlink.RouteAdd(route); err != nil {
|
||||||
|
return fmt.Errorf("failed to add route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse destination CIDR
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create route to delete
|
||||||
|
route := &netlink.Route{
|
||||||
|
Dst: ipNet,
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removing route to %s", destination)
|
||||||
|
|
||||||
|
// Delete the route
|
||||||
|
if err := netlink.RouteDel(route); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
|
func AddRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
|
if err := AddRouteForNetworkConfig(serverIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
|
func RemoveRouteForServerIP(serverIP string, interfaceName string) error {
|
||||||
|
if err := RemoveRouteForNetworkConfig(serverIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinRemoveRoute(serverIP)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsRemoveRoute(serverIP)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxRemoveRoute(serverIP)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddRouteForNetworkConfig(destination string) error {
|
||||||
|
// Parse the subnet to extract IP and mask
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
|
AddIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func RemoveRouteForNetworkConfig(destination string) error {
|
||||||
|
// Parse the subnet to extract IP and mask
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
|
RemoveIPv4IncludedRoute(IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutes adds routes for each subnet in RemoteSubnets
|
||||||
|
func AddRoutes(remoteSubnets []string, interfaceName string) error {
|
||||||
|
if len(remoteSubnets) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add routes for each subnet
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := AddRouteForNetworkConfig(subnet); err != nil {
|
||||||
|
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add route based on operating system
|
||||||
|
if interfaceName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Added route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets
|
||||||
|
func RemoveRoutes(remoteSubnets []string) error {
|
||||||
|
if len(remoteSubnets) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove routes for each subnet
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := RemoveRouteForNetworkConfig(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
11
network/route_notwindows.go
Normal file
11
network/route_notwindows.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package network
|
||||||
|
|
||||||
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsRemoveRoute(destination string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
148
network/route_windows.go
Normal file
148
network/route_windows.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"golang.zx2c4.com/wireguard/windows/tunnel/winipcfg"
|
||||||
|
)
|
||||||
|
|
||||||
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse destination CIDR
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to netip.Prefix
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
|
||||||
|
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||||
|
var addr netip.Addr
|
||||||
|
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||||
|
// IPv4 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ip4)
|
||||||
|
} else {
|
||||||
|
// IPv6 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||||
|
}
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return fmt.Errorf("failed to convert destination IP")
|
||||||
|
}
|
||||||
|
prefix := netip.PrefixFrom(addr, maskBits)
|
||||||
|
|
||||||
|
var luid winipcfg.LUID
|
||||||
|
var nextHop netip.Addr
|
||||||
|
|
||||||
|
if interfaceName != "" {
|
||||||
|
// Get the interface LUID - needed for both gateway and interface-only routes
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
luid, err = winipcfg.LUIDFromIndex(uint32(iface.Index))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get LUID for interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
gwIP := net.ParseIP(gateway)
|
||||||
|
if gwIP == nil {
|
||||||
|
return fmt.Errorf("invalid gateway address: %s", gateway)
|
||||||
|
}
|
||||||
|
// Convert to correct IP version
|
||||||
|
if ip4 := gwIP.To4(); ip4 != nil {
|
||||||
|
nextHop, _ = netip.AddrFromSlice(ip4)
|
||||||
|
} else {
|
||||||
|
nextHop, _ = netip.AddrFromSlice(gwIP)
|
||||||
|
}
|
||||||
|
if !nextHop.IsValid() {
|
||||||
|
return fmt.Errorf("failed to convert gateway IP")
|
||||||
|
}
|
||||||
|
logger.Info("Adding route to %s via gateway %s on interface %s", destination, gateway, interfaceName)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface only
|
||||||
|
if addr.Is4() {
|
||||||
|
nextHop = netip.IPv4Unspecified()
|
||||||
|
} else {
|
||||||
|
nextHop = netip.IPv6Unspecified()
|
||||||
|
}
|
||||||
|
logger.Info("Adding route to %s via interface %s", destination, interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the route using winipcfg
|
||||||
|
err = luid.AddRoute(prefix, nextHop, 1)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to add route: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsRemoveRoute(destination string) error {
|
||||||
|
// Parse destination CIDR
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to netip.Prefix
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
|
||||||
|
// Ensure we convert to the correct IP version (IPv4 vs IPv6)
|
||||||
|
var addr netip.Addr
|
||||||
|
if ip4 := ipNet.IP.To4(); ip4 != nil {
|
||||||
|
// IPv4 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ip4)
|
||||||
|
} else {
|
||||||
|
// IPv6 address
|
||||||
|
addr, _ = netip.AddrFromSlice(ipNet.IP)
|
||||||
|
}
|
||||||
|
if !addr.IsValid() {
|
||||||
|
return fmt.Errorf("failed to convert destination IP")
|
||||||
|
}
|
||||||
|
prefix := netip.PrefixFrom(addr, maskBits)
|
||||||
|
|
||||||
|
// Get all routes and find the one to delete
|
||||||
|
// We need to get the LUID from the existing route
|
||||||
|
var family winipcfg.AddressFamily
|
||||||
|
if addr.Is4() {
|
||||||
|
family = 2 // AF_INET
|
||||||
|
} else {
|
||||||
|
family = 23 // AF_INET6
|
||||||
|
}
|
||||||
|
|
||||||
|
routes, err := winipcfg.GetIPForwardTable2(family)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get route table: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find and delete matching route
|
||||||
|
for _, route := range routes {
|
||||||
|
routePrefix := route.DestinationPrefix.Prefix()
|
||||||
|
if routePrefix == prefix {
|
||||||
|
logger.Info("Removing route to %s", destination)
|
||||||
|
err = route.Delete()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete route: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("route to %s not found", destination)
|
||||||
|
}
|
||||||
190
network/settings.go
Normal file
190
network/settings.go
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkSettings represents the network configuration for the tunnel
|
||||||
|
type NetworkSettings struct {
|
||||||
|
TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"`
|
||||||
|
MTU *int `json:"mtu,omitempty"`
|
||||||
|
DNSServers []string `json:"dns_servers,omitempty"`
|
||||||
|
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
|
||||||
|
IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"`
|
||||||
|
IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"`
|
||||||
|
IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"`
|
||||||
|
IPv6Addresses []string `json:"ipv6_addresses,omitempty"`
|
||||||
|
IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"`
|
||||||
|
IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"`
|
||||||
|
IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv4Route represents an IPv4 route
|
||||||
|
type IPv4Route struct {
|
||||||
|
DestinationAddress string `json:"destination_address"`
|
||||||
|
SubnetMask string `json:"subnet_mask,omitempty"`
|
||||||
|
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||||
|
IsDefault bool `json:"is_default,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6Route represents an IPv6 route
|
||||||
|
type IPv6Route struct {
|
||||||
|
DestinationAddress string `json:"destination_address"`
|
||||||
|
NetworkPrefixLength int `json:"network_prefix_length,omitempty"`
|
||||||
|
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||||
|
IsDefault bool `json:"is_default,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
networkSettings NetworkSettings
|
||||||
|
networkSettingsMutex sync.RWMutex
|
||||||
|
incrementor int
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetTunnelRemoteAddress sets the tunnel remote address
|
||||||
|
func SetTunnelRemoteAddress(address string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.TunnelRemoteAddress = address
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set tunnel remote address: %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMTU sets the MTU value
|
||||||
|
func SetMTU(mtu int) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.MTU = &mtu
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set MTU: %d", mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDNSServers sets the DNS servers
|
||||||
|
func SetDNSServers(servers []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.DNSServers = servers
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set DNS servers: %v", servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv4Settings sets IPv4 addresses and subnet masks
|
||||||
|
func SetIPv4Settings(addresses []string, subnetMasks []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4Addresses = addresses
|
||||||
|
networkSettings.IPv4SubnetMasks = subnetMasks
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv4IncludedRoutes sets the included IPv4 routes
|
||||||
|
func SetIPv4IncludedRoutes(routes []IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4IncludedRoutes = routes
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv4 included routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddIPv4IncludedRoute(route IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
|
||||||
|
// make sure it does not already exist
|
||||||
|
for _, r := range networkSettings.IPv4IncludedRoutes {
|
||||||
|
if r == route {
|
||||||
|
logger.Info("IPv4 included route already exists: %+v", route)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route)
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Added IPv4 included route: %+v", route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RemoveIPv4IncludedRoute(route IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
routes := networkSettings.IPv4IncludedRoutes
|
||||||
|
for i, r := range routes {
|
||||||
|
if r == route {
|
||||||
|
networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...)
|
||||||
|
logger.Info("Removed IPv4 included route: %+v", route)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
incrementor++
|
||||||
|
logger.Info("IPv4 included route not found for removal: %+v", route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetIPv4ExcludedRoutes(routes []IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4ExcludedRoutes = routes
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv4 excluded routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6Settings sets IPv6 addresses and network prefixes
|
||||||
|
func SetIPv6Settings(addresses []string, networkPrefixes []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6Addresses = addresses
|
||||||
|
networkSettings.IPv6NetworkPrefixes = networkPrefixes
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6IncludedRoutes sets the included IPv6 routes
|
||||||
|
func SetIPv6IncludedRoutes(routes []IPv6Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6IncludedRoutes = routes
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv6 included routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6ExcludedRoutes sets the excluded IPv6 routes
|
||||||
|
func SetIPv6ExcludedRoutes(routes []IPv6Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6ExcludedRoutes = routes
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Set IPv6 excluded routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearNetworkSettings clears all network settings
|
||||||
|
func ClearNetworkSettings() {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings = NetworkSettings{}
|
||||||
|
incrementor++
|
||||||
|
logger.Info("Cleared all network settings")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJSON() (string, error) {
|
||||||
|
networkSettingsMutex.RLock()
|
||||||
|
defer networkSettingsMutex.RUnlock()
|
||||||
|
data, err := json.MarshalIndent(networkSettings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSettings() NetworkSettings {
|
||||||
|
networkSettingsMutex.RLock()
|
||||||
|
defer networkSettingsMutex.RUnlock()
|
||||||
|
return networkSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetIncrementor() int {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
return incrementor
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user