diff --git a/network/network.go b/network/network.go new file mode 100644 index 0000000..c5d4500 --- /dev/null +++ b/network/network.go @@ -0,0 +1,165 @@ +package network + +import ( + "encoding/json" + "sync" + + "github.com/fosrl/newt/logger" +) + +// NetworkSettings represents the network configuration for the tunnel +type NetworkSettings struct { + TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"` + MTU *int `json:"mtu,omitempty"` + DNSServers []string `json:"dns_servers,omitempty"` + IPv4Addresses []string `json:"ipv4_addresses,omitempty"` + IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"` + IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"` + IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"` + IPv6Addresses []string `json:"ipv6_addresses,omitempty"` + IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"` + IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"` + IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"` +} + +// IPv4Route represents an IPv4 route +type IPv4Route struct { + DestinationAddress string `json:"destination_address"` + SubnetMask string `json:"subnet_mask,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +// IPv6Route represents an IPv6 route +type IPv6Route struct { + DestinationAddress string `json:"destination_address"` + NetworkPrefixLength int `json:"network_prefix_length,omitempty"` + GatewayAddress string `json:"gateway_address,omitempty"` + IsDefault bool `json:"is_default,omitempty"` +} + +var ( + networkSettings NetworkSettings + networkSettingsMutex sync.RWMutex +) + +// SetTunnelRemoteAddress sets the tunnel remote address +func SetTunnelRemoteAddress(address string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.TunnelRemoteAddress = address + logger.Info("Set tunnel remote address: %s", address) +} + +// SetMTU sets the MTU value +func SetMTU(mtu int) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.MTU = &mtu + logger.Info("Set MTU: %d", mtu) +} + +// SetDNSServers sets the DNS servers +func SetDNSServers(servers []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.DNSServers = servers + logger.Info("Set DNS servers: %v", servers) +} + +// SetIPv4Settings sets IPv4 addresses and subnet masks +func SetIPv4Settings(addresses []string, subnetMasks []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4Addresses = addresses + networkSettings.IPv4SubnetMasks = subnetMasks + logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks) +} + +// SetIPv4IncludedRoutes sets the included IPv4 routes +func SetIPv4IncludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4IncludedRoutes = routes + logger.Info("Set IPv4 included routes: %d routes", len(routes)) +} + +func AddIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + + // make sure it does not already exist + for _, r := range networkSettings.IPv4IncludedRoutes { + if r == route { + logger.Info("IPv4 included route already exists: %+v", route) + return + } + } + + networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route) + logger.Info("Added IPv4 included route: %+v", route) +} + +func RemoveIPv4IncludedRoute(route IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + routes := networkSettings.IPv4IncludedRoutes + for i, r := range routes { + if r == route { + networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...) + logger.Info("Removed IPv4 included route: %+v", route) + return + } + } + logger.Info("IPv4 included route not found for removal: %+v", route) +} + +func SetIPv4ExcludedRoutes(routes []IPv4Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv4ExcludedRoutes = routes + logger.Info("Set IPv4 excluded routes: %d routes", len(routes)) +} + +// SetIPv6Settings sets IPv6 addresses and network prefixes +func SetIPv6Settings(addresses []string, networkPrefixes []string) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6Addresses = addresses + networkSettings.IPv6NetworkPrefixes = networkPrefixes + logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes) +} + +// SetIPv6IncludedRoutes sets the included IPv6 routes +func SetIPv6IncludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6IncludedRoutes = routes + logger.Info("Set IPv6 included routes: %d routes", len(routes)) +} + +// SetIPv6ExcludedRoutes sets the excluded IPv6 routes +func SetIPv6ExcludedRoutes(routes []IPv6Route) { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings.IPv6ExcludedRoutes = routes + logger.Info("Set IPv6 excluded routes: %d routes", len(routes)) +} + +// ClearNetworkSettings clears all network settings +func ClearNetworkSettings() { + networkSettingsMutex.Lock() + defer networkSettingsMutex.Unlock() + networkSettings = NetworkSettings{} + logger.Info("Cleared all network settings") +} + +func GetNetworkSettingsJSON() (string, error) { + networkSettingsMutex.RLock() + defer networkSettingsMutex.RUnlock() + data, err := json.MarshalIndent(networkSettings, "", " ") + if err != nil { + return "", err + } + return string(data), nil +} diff --git a/olm/common.go b/olm/common.go index 6ebfb51..2dafe3e 100644 --- a/olm/common.go +++ b/olm/common.go @@ -3,116 +3,13 @@ package olm import ( "fmt" "net" - "os/exec" - "regexp" - "runtime" - "strconv" "strings" "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/util" - "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" - "github.com/vishvananda/netlink" - "golang.zx2c4.com/wireguard/device" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -type WgData struct { - Sites []SiteConfig `json:"sites"` - TunnelIP string `json:"tunnelIP"` -} - -type SiteConfig struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -type TargetsByType struct { - UDP []string `json:"udp"` - TCP []string `json:"tcp"` -} - -type TargetData struct { - Targets []string `json:"targets"` -} - -type HolePunchMessage struct { - NewtID string `json:"newtId"` -} - -type ExitNode struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - -type HolePunchData struct { - ExitNodes []ExitNode `json:"exitNodes"` -} - -type EncryptedHolePunchMessage struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` -} - -var ( - peerMonitor *peermonitor.PeerMonitor - stopHolepunch chan struct{} - stopRegister func() - stopPing chan struct{} - olmToken string - holePunchRunning bool -) - -const ( - ENV_WG_TUN_FD = "WG_TUN_FD" - ENV_WG_UAPI_FD = "WG_UAPI_FD" - ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND" -) - -// PeerAction represents a request to add, update, or remove a peer -type PeerAction struct { - Action string `json:"action"` // "add", "update", or "remove" - SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information -} - -// UpdatePeerData represents the data needed to update a peer -type UpdatePeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -// AddPeerData represents the data needed to add a peer -type AddPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access -} - -// RemovePeerData represents the data needed to remove a peer -type RemovePeerData struct { - SiteId int `json:"siteId"` -} - -type RelayPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - // Helper function to format endpoints correctly func formatEndpoint(endpoint string) string { if endpoint == "" { @@ -140,21 +37,6 @@ func formatEndpoint(endpoint string) string { return endpoint } -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - func sendPing(olm *websocket.Client) error { err := olm.SendMessage("olm/ping", map[string]interface{}{ "timestamp": time.Now().Unix(), @@ -192,578 +74,3 @@ func keepSendingPing(olm *websocket.Client) { } } } - -// ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { - siteHost, err := util.ResolveDomain(siteConfig.Endpoint) - if err != nil { - return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) - } - - // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP - allowedIp := strings.Split(siteConfig.ServerIP, "/") - if len(allowedIp) > 1 { - allowedIp[1] = "32" - } else { - allowedIp = append(allowedIp, "32") - } - allowedIpStr := strings.Join(allowedIp, "/") - - // Collect all allowed IPs in a slice - var allowedIPs []string - allowedIPs = append(allowedIPs, allowedIpStr) - - // If we have anything in remoteSubnets, add those as well - if siteConfig.RemoteSubnets != "" { - // Split remote subnets by comma and add each one - remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") - for _, subnet := range remoteSubnets { - subnet = strings.TrimSpace(subnet) - if subnet != "" { - allowedIPs = append(allowedIPs, subnet) - } - } - } - - // Construct WireGuard config for this peer - var configBuilder strings.Builder - configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String()))) - configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey))) - - // Add each allowed IP separately - for _, allowedIP := range allowedIPs { - configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) - } - - configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=1\n") - - config := configBuilder.String() - logger.Debug("Configuring peer with config: %s", config) - - err = dev.IpcSet(config) - if err != nil { - return fmt.Errorf("failed to configure WireGuard peer: %v", err) - } - - // Set up peer monitoring - if peerMonitor != nil { - monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] - monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port - logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) - - primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint: %v", err) - } - - wgConfig := &peermonitor.WireGuardConfig{ - SiteID: siteConfig.SiteId, - PublicKey: util.FixKey(siteConfig.PublicKey), - ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], - Endpoint: siteConfig.Endpoint, - PrimaryRelay: primaryRelay, - } - - err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) - if err != nil { - logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) - } else { - logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) - } - } - - return nil -} - -// RemovePeer removes a peer from the WireGuard device -func RemovePeer(dev *device.Device, siteId int, publicKey string) error { - // Construct WireGuard config to remove the peer - var configBuilder strings.Builder - configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) - configBuilder.WriteString("remove=true\n") - - config := configBuilder.String() - logger.Debug("Removing peer with config: %s", config) - - err := dev.IpcSet(config) - if err != nil { - return fmt.Errorf("failed to remove WireGuard peer: %v", err) - } - - // Stop monitoring this peer - if peerMonitor != nil { - peerMonitor.RemovePeer(siteId) - logger.Info("Stopped monitoring for site %d", siteId) - } - - return nil -} - -// ConfigureInterface configures a network interface with an IP address and brings it up -func ConfigureInterface(interfaceName string, wgData WgData) error { - var ipAddr string = wgData.TunnelIP - - // Parse the IP address and network - ip, ipNet, err := net.ParseCIDR(ipAddr) - if err != nil { - return fmt.Errorf("invalid IP address: %v", err) - } - - switch runtime.GOOS { - case "linux": - return configureLinux(interfaceName, ip, ipNet) - case "darwin": - return configureDarwin(interfaceName, ip, ipNet) - case "windows": - return configureWindows(interfaceName, ip, ipNet) - default: - return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) - } -} - -func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring Windows interface: %s", interfaceName) - - // Calculate mask string (e.g., 255.255.255.0) - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - // Set the IP address using netsh - cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", - fmt.Sprintf("name=%s", interfaceName), - "source=static", - fmt.Sprintf("addr=%s", ip.String()), - fmt.Sprintf("mask=%s", maskIP.String())) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh command failed: %v, output: %s", err, out) - } - - // Bring up the interface if needed (in Windows, setting the IP usually brings it up) - // But we'll explicitly enable it to be sure - cmd = exec.Command("netsh", "interface", "set", "interface", - interfaceName, - "admin=enable") - - logger.Info("Running command: %v", cmd) - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) - } - - // delay 2 seconds - time.Sleep(8 * time.Second) - - // Wait for the interface to be up and have the correct IP - err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) - if err != nil { - return fmt.Errorf("interface did not come up within timeout: %v", err) - } - - return nil -} - -// waitForInterfaceUp polls the network interface until it's up or times out -func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { - logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) - deadline := time.Now().Add(timeout) - pollInterval := 500 * time.Millisecond - - for time.Now().Before(deadline) { - // Check if interface exists and is up - iface, err := net.InterfaceByName(interfaceName) - if err == nil { - // Check if interface is up - if iface.Flags&net.FlagUp != 0 { - // Check if it has the expected IP - addrs, err := iface.Addrs() - if err == nil { - for _, addr := range addrs { - ipNet, ok := addr.(*net.IPNet) - if ok && ipNet.IP.Equal(expectedIP) { - logger.Info("Interface %s is up with correct IP", interfaceName) - return nil // Interface is up with correct IP - } - } - logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) - } - } else { - logger.Info("Interface %s exists but is not up yet", interfaceName) - } - } else { - logger.Info("Interface %s not found yet: %v", interfaceName, err) - } - - // Wait before next check - time.Sleep(pollInterval) - } - - return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) -} - -func WindowsAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "windows" { - return nil - } - - var cmd *exec.Cmd - - // Parse destination to get the IP and subnet - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "add", - ip.String(), - "mask", maskIP.String(), - gateway, - "metric", "1") - } else if interfaceName != "" { - // First, get the interface index - indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") - output, err := indexCmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) - } - - // Parse the output to find the interface index - lines := strings.Split(string(output), "\n") - var ifIndex string - for _, line := range lines { - if strings.Contains(line, interfaceName) { - fields := strings.Fields(line) - if len(fields) > 0 { - ifIndex = fields[0] - break - } - } - } - - if ifIndex == "" { - return fmt.Errorf("could not find index for interface %s", interfaceName) - } - - // Convert to integer to validate - idx, err := strconv.Atoi(ifIndex) - if err != nil { - return fmt.Errorf("invalid interface index: %v", err) - } - - // Route via interface using the index - cmd = exec.Command("route", "add", - ip.String(), - "mask", maskIP.String(), - "0.0.0.0", - "if", strconv.Itoa(idx)) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func WindowsRemoveRoute(destination string) error { - // Parse destination to get the IP - ip, ipNet, err := net.ParseCIDR(destination) - if err != nil { - return fmt.Errorf("invalid destination address: %v", err) - } - - // Calculate the subnet mask - maskBits, _ := ipNet.Mask.Size() - mask := net.CIDRMask(maskBits, 32) - maskIP := net.IP(mask) - - cmd := exec.Command("route", "delete", - ip.String(), - "mask", maskIP.String()) - - logger.Info("Running command: %v", cmd) - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -func findUnusedUTUN() (string, error) { - ifaces, err := net.Interfaces() - if err != nil { - return "", fmt.Errorf("failed to list interfaces: %v", err) - } - used := make(map[int]bool) - re := regexp.MustCompile(`^utun(\d+)$`) - for _, iface := range ifaces { - if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { - if num, err := strconv.Atoi(matches[1]); err == nil { - used[num] = true - } - } - } - // Try utun0 up to utun255. - for i := 0; i < 256; i++ { - if !used[i] { - return fmt.Sprintf("utun%d", i), nil - } - } - return "", fmt.Errorf("no unused utun interface found") -} - -func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - logger.Info("Configuring darwin interface: %s", interfaceName) - - prefix, _ := ipNet.Mask.Size() - ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) - - cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) - } - - // Bring up the interface - cmd = exec.Command("ifconfig", interfaceName, "up") - logger.Info("Running command: %v", cmd) - - out, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) - } - - return nil -} - -func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { - // Get the interface - link, err := netlink.LinkByName(interfaceName) - if err != nil { - return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) - } - - // Create the IP address attributes - addr := &netlink.Addr{ - IPNet: &net.IPNet{ - IP: ip, - Mask: ipNet.Mask, - }, - } - - // Add the IP address to the interface - if err := netlink.AddrAdd(link, addr); err != nil { - return fmt.Errorf("failed to add IP address: %v", err) - } - - // Bring up the interface - if err := netlink.LinkSetUp(link); err != nil { - return fmt.Errorf("failed to bring up interface: %v", err) - } - - return nil -} - -func DarwinAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "darwin" { - return nil - } - - var cmd *exec.Cmd - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) - } else if interfaceName != "" { - // Route via interface - cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route command failed: %v, output: %s", err, out) - } - - return nil -} - -func DarwinRemoveRoute(destination string) error { - if runtime.GOOS != "darwin" { - return nil - } - - cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -func LinuxAddRoute(destination string, gateway string, interfaceName string) error { - if runtime.GOOS != "linux" { - return nil - } - - var cmd *exec.Cmd - - if gateway != "" { - // Route with specific gateway - cmd = exec.Command("ip", "route", "add", destination, "via", gateway) - } else if interfaceName != "" { - // Route via interface - cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) - } else { - return fmt.Errorf("either gateway or interface must be specified") - } - - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ip route command failed: %v, output: %s", err, out) - } - - return nil -} - -func LinuxRemoveRoute(destination string) error { - if runtime.GOOS != "linux" { - return nil - } - - cmd := exec.Command("ip", "route", "del", destination) - logger.Info("Running command: %v", cmd) - - out, err := cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) - } - - return nil -} - -// addRouteForServerIP adds an OS-specific route for the server IP -func addRouteForServerIP(serverIP, interfaceName string) error { - if runtime.GOOS == "darwin" { - return DarwinAddRoute(serverIP, "", interfaceName) - } - // else if runtime.GOOS == "windows" { - // return WindowsAddRoute(serverIP, "", interfaceName) - // } else if runtime.GOOS == "linux" { - // return LinuxAddRoute(serverIP, "", interfaceName) - // } - return nil -} - -// removeRouteForServerIP removes an OS-specific route for the server IP -func removeRouteForServerIP(serverIP string) error { - if runtime.GOOS == "darwin" { - return DarwinRemoveRoute(serverIP) - } - // else if runtime.GOOS == "windows" { - // return WindowsRemoveRoute(serverIP) - // } else if runtime.GOOS == "linux" { - // return LinuxRemoveRoute(serverIP) - // } - return nil -} - -// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets -func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { - if remoteSubnets == "" { - return nil - } - - // Split remote subnets by comma and add routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { - subnet = strings.TrimSpace(subnet) - if subnet == "" { - continue - } - - // Add route based on operating system - if runtime.GOOS == "darwin" { - if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "windows" { - if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "linux" { - if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { - logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) - return err - } - } - - logger.Info("Added route for remote subnet: %s", subnet) - } - return nil -} - -// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets -func removeRoutesForRemoteSubnets(remoteSubnets string) error { - if remoteSubnets == "" { - return nil - } - - // Split remote subnets by comma and remove routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { - subnet = strings.TrimSpace(subnet) - if subnet == "" { - continue - } - - // Remove route based on operating system - if runtime.GOOS == "darwin" { - if err := DarwinRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "windows" { - if err := WindowsRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) - return err - } - } else if runtime.GOOS == "linux" { - if err := LinuxRemoveRoute(subnet); err != nil { - logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) - return err - } - } - - logger.Info("Removed route for remote subnet: %s", subnet) - } - return nil -} diff --git a/olm/interface.go b/olm/interface.go new file mode 100644 index 0000000..ab4b4fb --- /dev/null +++ b/olm/interface.go @@ -0,0 +1,213 @@ +package olm + +import ( + "fmt" + "net" + "os/exec" + "regexp" + "runtime" + "strconv" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/network" + "github.com/vishvananda/netlink" +) + +// ConfigureInterface configures a network interface with an IP address and brings it up +func ConfigureInterface(interfaceName string, wgData WgData) error { + if interfaceName == "" { + return nil + } + + var ipAddr string = wgData.TunnelIP + + // Parse the IP address and network + ip, ipNet, err := net.ParseCIDR(ipAddr) + if err != nil { + return fmt.Errorf("invalid IP address: %v", err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + // network.SetTunnelRemoteAddress() // what does this do? + network.SetIPv4Settings([]string{destinationAddress}, []string{mask}) + apiServer.SetTunnelIP(destinationAddress) + + if interfaceName == "" { + return nil + } + + switch runtime.GOOS { + case "linux": + return configureLinux(interfaceName, ip, ipNet) + case "darwin": + return configureDarwin(interfaceName, ip, ipNet) + case "windows": + return configureWindows(interfaceName, ip, ipNet) + default: + return fmt.Errorf("unsupported operating system: %s", runtime.GOOS) + } +} + +func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring Windows interface: %s", interfaceName) + + // Calculate mask string (e.g., 255.255.255.0) + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + // Set the IP address using netsh + cmd := exec.Command("netsh", "interface", "ipv4", "set", "address", + fmt.Sprintf("name=%s", interfaceName), + "source=static", + fmt.Sprintf("addr=%s", ip.String()), + fmt.Sprintf("mask=%s", maskIP.String())) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh command failed: %v, output: %s", err, out) + } + + // Bring up the interface if needed (in Windows, setting the IP usually brings it up) + // But we'll explicitly enable it to be sure + cmd = exec.Command("netsh", "interface", "set", "interface", + interfaceName, + "admin=enable") + + logger.Info("Running command: %v", cmd) + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out) + } + + // delay 2 seconds + time.Sleep(8 * time.Second) + + // Wait for the interface to be up and have the correct IP + err = waitForInterfaceUp(interfaceName, ip, 30*time.Second) + if err != nil { + return fmt.Errorf("interface did not come up within timeout: %v", err) + } + + return nil +} + +// waitForInterfaceUp polls the network interface until it's up or times out +func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error { + logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP) + deadline := time.Now().Add(timeout) + pollInterval := 500 * time.Millisecond + + for time.Now().Before(deadline) { + // Check if interface exists and is up + iface, err := net.InterfaceByName(interfaceName) + if err == nil { + // Check if interface is up + if iface.Flags&net.FlagUp != 0 { + // Check if it has the expected IP + addrs, err := iface.Addrs() + if err == nil { + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if ok && ipNet.IP.Equal(expectedIP) { + logger.Info("Interface %s is up with correct IP", interfaceName) + return nil // Interface is up with correct IP + } + } + logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName) + } + } else { + logger.Info("Interface %s exists but is not up yet", interfaceName) + } + } else { + logger.Info("Interface %s not found yet: %v", interfaceName, err) + } + + // Wait before next check + time.Sleep(pollInterval) + } + + return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP) +} + +func findUnusedUTUN() (string, error) { + ifaces, err := net.Interfaces() + if err != nil { + return "", fmt.Errorf("failed to list interfaces: %v", err) + } + used := make(map[int]bool) + re := regexp.MustCompile(`^utun(\d+)$`) + for _, iface := range ifaces { + if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 { + if num, err := strconv.Atoi(matches[1]); err == nil { + used[num] = true + } + } + } + // Try utun0 up to utun255. + for i := 0; i < 256; i++ { + if !used[i] { + return fmt.Sprintf("utun%d", i), nil + } + } + return "", fmt.Errorf("no unused utun interface found") +} + +func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + logger.Info("Configuring darwin interface: %s", interfaceName) + + prefix, _ := ipNet.Mask.Size() + ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix) + + cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias") + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out) + } + + // Bring up the interface + cmd = exec.Command("ifconfig", interfaceName, "up") + logger.Info("Running command: %v", cmd) + + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out) + } + + return nil +} + +func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error { + // Get the interface + link, err := netlink.LinkByName(interfaceName) + if err != nil { + return fmt.Errorf("failed to get interface %s: %v", interfaceName, err) + } + + // Create the IP address attributes + addr := &netlink.Addr{ + IPNet: &net.IPNet{ + IP: ip, + Mask: ipNet.Mask, + }, + } + + // Add the IP address to the interface + if err := netlink.AddrAdd(link, addr); err != nil { + return fmt.Errorf("failed to add IP address: %v", err) + } + + // Bring up the interface + if err := netlink.LinkSetUp(link); err != nil { + return fmt.Errorf("failed to bring up interface: %v", err) + } + + return nil +} diff --git a/olm/olm.go b/olm/olm.go index af68487..960d9cf 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -4,9 +4,7 @@ import ( "context" "encoding/json" "net" - "os" "runtime" - "strconv" "time" "github.com/fosrl/newt/bind" @@ -15,6 +13,7 @@ import ( "github.com/fosrl/newt/updates" "github.com/fosrl/newt/util" "github.com/fosrl/olm/api" + "github.com/fosrl/olm/network" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -57,7 +56,8 @@ type Config struct { OrgID string // DoNotCreateNewClient bool - FileDescriptorTun uint32 + FileDescriptorTun uint32 + FileDescriptorUAPI uint32 } var ( @@ -82,6 +82,7 @@ func Run(ctx context.Context, config Config) { defer cancel() logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) + network.SetMTU(config.MTU) if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil { logger.Debug("Failed to check for updates: %v", err) @@ -371,14 +372,14 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, if config.FileDescriptorTun != 0 { return createTUNFromFD(config.FileDescriptorTun, config.MTU) } + var ifName = interfaceName if runtime.GOOS == "darwin" { // this is if we dont pass a fd - interfaceName, err := findUnusedUTUN() + ifName, err = findUnusedUTUN() if err != nil { return nil, err } - return tun.CreateTUN(interfaceName, config.MTU) } - return tun.CreateTUN(interfaceName, config.MTU) + return tun.CreateTUN(ifName, config.MTU) }() if err != nil { @@ -386,45 +387,47 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, return } - if realInterfaceName, err2 := tdev.Name(); err2 == nil { - interfaceName = realInterfaceName - } - - fileUAPI, err := func() (*os.File, error) { - if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" { - fd, err := strconv.ParseUint(uapiFdStr, 10, 32) - if err != nil { - return nil, err - } - return os.NewFile(uintptr(fd), ""), nil + if config.FileDescriptorTun == 0 { + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + interfaceName = realInterfaceName } - return uapiOpen(interfaceName) - }() - if err != nil { - logger.Error("UAPI listen error: %v", err) - os.Exit(1) - return } - dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) + // fileUAPI, err := func() (*os.File, error) { + // if config.FileDescriptorUAPI != 0 { + // fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) + // if err != nil { + // return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err) + // } + // return os.NewFile(uintptr(fd), ""), nil + // } + // return uapiOpen(interfaceName) + // }() + // if err != nil { + // logger.Error("UAPI listen error: %v", err) + // os.Exit(1) + // return + // } - uapiListener, err = uapiListen(interfaceName, fileUAPI) - if err != nil { - logger.Error("Failed to listen on uapi socket: %v", err) - os.Exit(1) - } + dev = device.NewDevice(tdev, sharedBind, device.NewLogger(util.MapToWireGuardLogLevel(loggerLevel), "wireguard: ")) - go func() { - for { - conn, err := uapiListener.Accept() - if err != nil { + // uapiListener, err = uapiListen(interfaceName, fileUAPI) + // if err != nil { + // logger.Error("Failed to listen on uapi socket: %v", err) + // os.Exit(1) + // } - return - } - go dev.IpcHandle(conn) - } - }() - logger.Info("UAPI listener started") + // go func() { + // for { + // conn, err := uapiListener.Accept() + // if err != nil { + + // return + // } + // go dev.IpcHandle(conn) + // } + // }() + // logger.Info("UAPI listener started") if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) @@ -432,7 +435,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } - apiServer.SetTunnelIP(wgData.TunnelIP) peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { @@ -476,10 +478,10 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, logger.Error("Failed to add route for peer: %v", err) return } - if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add routes for remote subnets: %v", err) - return - } + // if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { + // logger.Error("Failed to add routes for remote subnets: %v", err) + // return + // } logger.Info("Configured peer %s", site.PublicKey) } @@ -671,7 +673,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, } // Remove route for the peer - err = removeRouteForServerIP(peerToRemove.ServerIP) + err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName) if err != nil { logger.Error("Failed to remove route for peer: %v", err) return diff --git a/olm/peer.go b/olm/peer.go new file mode 100644 index 0000000..febf5bd --- /dev/null +++ b/olm/peer.go @@ -0,0 +1,121 @@ +package olm + +import ( + "fmt" + "net" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/peermonitor" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// ConfigurePeer sets up or updates a peer within the WireGuard device +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { + siteHost, err := util.ResolveDomain(siteConfig.Endpoint) + if err != nil { + return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) + } + + // Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP + allowedIp := strings.Split(siteConfig.ServerIP, "/") + if len(allowedIp) > 1 { + allowedIp[1] = "32" + } else { + allowedIp = append(allowedIp, "32") + } + allowedIpStr := strings.Join(allowedIp, "/") + + // Collect all allowed IPs in a slice + var allowedIPs []string + allowedIPs = append(allowedIPs, allowedIpStr) + + // If we have anything in remoteSubnets, add those as well + if siteConfig.RemoteSubnets != "" { + // Split remote subnets by comma and add each one + remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet != "" { + allowedIPs = append(allowedIPs, subnet) + } + } + } + + // Construct WireGuard config for this peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String()))) + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey))) + + // Add each allowed IP separately + for _, allowedIP := range allowedIPs { + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + } + + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=1\n") + + config := configBuilder.String() + logger.Debug("Configuring peer with config: %s", config) + + err = dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to configure WireGuard peer: %v", err) + } + + // Set up peer monitoring + if peerMonitor != nil { + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port + logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) + + primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable + if err != nil { + logger.Warn("Failed to resolve primary relay endpoint: %v", err) + } + + wgConfig := &peermonitor.WireGuardConfig{ + SiteID: siteConfig.SiteId, + PublicKey: util.FixKey(siteConfig.PublicKey), + ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], + Endpoint: siteConfig.Endpoint, + PrimaryRelay: primaryRelay, + } + + err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) + } + } + + return nil +} + +// RemovePeer removes a peer from the WireGuard device +func RemovePeer(dev *device.Device, siteId int, publicKey string) error { + // Construct WireGuard config to remove the peer + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("remove=true\n") + + config := configBuilder.String() + logger.Debug("Removing peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to remove WireGuard peer: %v", err) + } + + // Stop monitoring this peer + if peerMonitor != nil { + peerMonitor.RemovePeer(siteId) + logger.Info("Stopped monitoring for site %d", siteId) + } + + return nil +} diff --git a/olm/route.go b/olm/route.go new file mode 100644 index 0000000..cc991fc --- /dev/null +++ b/olm/route.go @@ -0,0 +1,358 @@ +package olm + +import ( + "fmt" + "net" + "os/exec" + "runtime" + "strconv" + "strings" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/network" +) + +func DarwinAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "darwin" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func DarwinRemoveRoute(destination string) error { + if runtime.GOOS != "darwin" { + return nil + } + + cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "linux" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("ip", "route", "add", destination, "via", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxRemoveRoute(destination string) error { + if runtime.GOOS != "linux" { + return nil + } + + cmd := exec.Command("ip", "route", "del", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +func WindowsAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "windows" { + return nil + } + + var cmd *exec.Cmd + + // Parse destination to get the IP and subnet + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + gateway, + "metric", "1") + } else if interfaceName != "" { + // First, get the interface index + indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces") + output, err := indexCmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to get interface index: %v, output: %s", err, output) + } + + // Parse the output to find the interface index + lines := strings.Split(string(output), "\n") + var ifIndex string + for _, line := range lines { + if strings.Contains(line, interfaceName) { + fields := strings.Fields(line) + if len(fields) > 0 { + ifIndex = fields[0] + break + } + } + } + + if ifIndex == "" { + return fmt.Errorf("could not find index for interface %s", interfaceName) + } + + // Convert to integer to validate + idx, err := strconv.Atoi(ifIndex) + if err != nil { + return fmt.Errorf("invalid interface index: %v", err) + } + + // Route via interface using the index + cmd = exec.Command("route", "add", + ip.String(), + "mask", maskIP.String(), + "0.0.0.0", + "if", strconv.Itoa(idx)) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route command failed: %v, output: %s", err, out) + } + + return nil +} + +func WindowsRemoveRoute(destination string) error { + // Parse destination to get the IP + ip, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("invalid destination address: %v", err) + } + + // Calculate the subnet mask + maskBits, _ := ipNet.Mask.Size() + mask := net.CIDRMask(maskBits, 32) + maskIP := net.IP(mask) + + cmd := exec.Command("route", "delete", + ip.String(), + "mask", maskIP.String()) + + logger.Info("Running command: %v", cmd) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +// addRouteForServerIP adds an OS-specific route for the server IP +func addRouteForServerIP(serverIP, interfaceName string) error { + if err := addRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinAddRoute(serverIP, "", interfaceName) + } + // else if runtime.GOOS == "windows" { + // return WindowsAddRoute(serverIP, "", interfaceName) + // } else if runtime.GOOS == "linux" { + // return LinuxAddRoute(serverIP, "", interfaceName) + // } + return nil +} + +// removeRouteForServerIP removes an OS-specific route for the server IP +func removeRouteForServerIP(serverIP string, interfaceName string) error { + if err := removeRouteForNetworkConfig(serverIP); err != nil { + return err + } + if interfaceName == "" { + return nil + } + if runtime.GOOS == "darwin" { + return DarwinRemoveRoute(serverIP) + } + // else if runtime.GOOS == "windows" { + // return WindowsRemoveRoute(serverIP) + // } else if runtime.GOOS == "linux" { + // return LinuxRemoveRoute(serverIP) + // } + return nil +} + +func addRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + network.AddIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +func removeRouteForNetworkConfig(destination string) error { + // Parse the subnet to extract IP and mask + _, ipNet, err := net.ParseCIDR(destination) + if err != nil { + return fmt.Errorf("failed to parse subnet %s: %v", destination, err) + } + + // Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0) + mask := net.IP(ipNet.Mask).String() + destinationAddress := ipNet.IP.String() + + network.RemoveIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask}) + + return nil +} + +// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets +func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and add routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := addRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to add network config for subnet %s: %v", subnet, err) + continue + } + + // Add route based on operating system + if interfaceName == "" { + continue + } + + if runtime.GOOS == "darwin" { + if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Added route for remote subnet: %s", subnet) + } + return nil +} + +// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets +func removeRoutesForRemoteSubnets(remoteSubnets string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and remove routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + if err := removeRouteForNetworkConfig(subnet); err != nil { + logger.Error("Failed to remove network config for subnet %s: %v", subnet, err) + continue + } + + // Remove route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Removed route for remote subnet: %s", subnet) + } + + return nil +} diff --git a/olm/types.go b/olm/types.go new file mode 100644 index 0000000..192f7fe --- /dev/null +++ b/olm/types.go @@ -0,0 +1,91 @@ +package olm + +import "github.com/fosrl/olm/peermonitor" + +type WgData struct { + Sites []SiteConfig `json:"sites"` + TunnelIP string `json:"tunnelIP"` +} + +type SiteConfig struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +type TargetsByType struct { + UDP []string `json:"udp"` + TCP []string `json:"tcp"` +} + +type TargetData struct { + Targets []string `json:"targets"` +} + +type HolePunchMessage struct { + NewtID string `json:"newtId"` +} + +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + +type HolePunchData struct { + ExitNodes []ExitNode `json:"exitNodes"` +} + +type EncryptedHolePunchMessage struct { + EphemeralPublicKey string `json:"ephemeralPublicKey"` + Nonce []byte `json:"nonce"` + Ciphertext []byte `json:"ciphertext"` +} + +var ( + peerMonitor *peermonitor.PeerMonitor + stopHolepunch chan struct{} + stopRegister func() + stopPing chan struct{} + olmToken string + holePunchRunning bool +) + +// PeerAction represents a request to add, update, or remove a peer +type PeerAction struct { + Action string `json:"action"` // "add", "update", or "remove" + SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information +} + +// UpdatePeerData represents the data needed to update a peer +type UpdatePeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +// AddPeerData represents the data needed to add a peer +type AddPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access +} + +// RemovePeerData represents the data needed to remove a peer +type RemovePeerData struct { + SiteId int `json:"siteId"` +} + +type RelayPeerData struct { + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} diff --git a/olm/unix.go b/olm/unix.go index 5f5cf0e..ffdf7e9 100644 --- a/olm/unix.go +++ b/olm/unix.go @@ -6,20 +6,26 @@ import ( "net" "os" + "github.com/fosrl/newt/logger" "golang.org/x/sys/unix" "golang.zx2c4.com/wireguard/ipc" "golang.zx2c4.com/wireguard/tun" ) func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) { - err := unix.SetNonblock(int(tunFd), true) + dupTunFd, err := unix.Dup(int(tunFd)) + if err != nil { + logger.Error("Unable to dup tun fd: %v", err) + return nil, err + } + err = unix.SetNonblock(dupTunFd, true) if err != nil { return nil, err } - file := os.NewFile(uintptr(tunFd), "") - return tun.CreateTUNFromFile(file, mtuInt) + return tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), mtuInt) } + func uapiOpen(interfaceName string) (*os.File, error) { return ipc.UAPIOpen(interfaceName) }