mirror of
https://github.com/fosrl/olm.git
synced 2026-03-05 10:16:46 +00:00
165
network/network.go
Normal file
165
network/network.go
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
package network
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NetworkSettings represents the network configuration for the tunnel
|
||||||
|
type NetworkSettings struct {
|
||||||
|
TunnelRemoteAddress string `json:"tunnel_remote_address,omitempty"`
|
||||||
|
MTU *int `json:"mtu,omitempty"`
|
||||||
|
DNSServers []string `json:"dns_servers,omitempty"`
|
||||||
|
IPv4Addresses []string `json:"ipv4_addresses,omitempty"`
|
||||||
|
IPv4SubnetMasks []string `json:"ipv4_subnet_masks,omitempty"`
|
||||||
|
IPv4IncludedRoutes []IPv4Route `json:"ipv4_included_routes,omitempty"`
|
||||||
|
IPv4ExcludedRoutes []IPv4Route `json:"ipv4_excluded_routes,omitempty"`
|
||||||
|
IPv6Addresses []string `json:"ipv6_addresses,omitempty"`
|
||||||
|
IPv6NetworkPrefixes []string `json:"ipv6_network_prefixes,omitempty"`
|
||||||
|
IPv6IncludedRoutes []IPv6Route `json:"ipv6_included_routes,omitempty"`
|
||||||
|
IPv6ExcludedRoutes []IPv6Route `json:"ipv6_excluded_routes,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv4Route represents an IPv4 route
|
||||||
|
type IPv4Route struct {
|
||||||
|
DestinationAddress string `json:"destination_address"`
|
||||||
|
SubnetMask string `json:"subnet_mask,omitempty"`
|
||||||
|
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||||
|
IsDefault bool `json:"is_default,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// IPv6Route represents an IPv6 route
|
||||||
|
type IPv6Route struct {
|
||||||
|
DestinationAddress string `json:"destination_address"`
|
||||||
|
NetworkPrefixLength int `json:"network_prefix_length,omitempty"`
|
||||||
|
GatewayAddress string `json:"gateway_address,omitempty"`
|
||||||
|
IsDefault bool `json:"is_default,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
networkSettings NetworkSettings
|
||||||
|
networkSettingsMutex sync.RWMutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetTunnelRemoteAddress sets the tunnel remote address
|
||||||
|
func SetTunnelRemoteAddress(address string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.TunnelRemoteAddress = address
|
||||||
|
logger.Info("Set tunnel remote address: %s", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetMTU sets the MTU value
|
||||||
|
func SetMTU(mtu int) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.MTU = &mtu
|
||||||
|
logger.Info("Set MTU: %d", mtu)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDNSServers sets the DNS servers
|
||||||
|
func SetDNSServers(servers []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.DNSServers = servers
|
||||||
|
logger.Info("Set DNS servers: %v", servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv4Settings sets IPv4 addresses and subnet masks
|
||||||
|
func SetIPv4Settings(addresses []string, subnetMasks []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4Addresses = addresses
|
||||||
|
networkSettings.IPv4SubnetMasks = subnetMasks
|
||||||
|
logger.Info("Set IPv4 addresses: %v, subnet masks: %v", addresses, subnetMasks)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv4IncludedRoutes sets the included IPv4 routes
|
||||||
|
func SetIPv4IncludedRoutes(routes []IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4IncludedRoutes = routes
|
||||||
|
logger.Info("Set IPv4 included routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddIPv4IncludedRoute(route IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
|
||||||
|
// make sure it does not already exist
|
||||||
|
for _, r := range networkSettings.IPv4IncludedRoutes {
|
||||||
|
if r == route {
|
||||||
|
logger.Info("IPv4 included route already exists: %+v", route)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
networkSettings.IPv4IncludedRoutes = append(networkSettings.IPv4IncludedRoutes, route)
|
||||||
|
logger.Info("Added IPv4 included route: %+v", route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func RemoveIPv4IncludedRoute(route IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
routes := networkSettings.IPv4IncludedRoutes
|
||||||
|
for i, r := range routes {
|
||||||
|
if r == route {
|
||||||
|
networkSettings.IPv4IncludedRoutes = append(routes[:i], routes[i+1:]...)
|
||||||
|
logger.Info("Removed IPv4 included route: %+v", route)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Info("IPv4 included route not found for removal: %+v", route)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetIPv4ExcludedRoutes(routes []IPv4Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv4ExcludedRoutes = routes
|
||||||
|
logger.Info("Set IPv4 excluded routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6Settings sets IPv6 addresses and network prefixes
|
||||||
|
func SetIPv6Settings(addresses []string, networkPrefixes []string) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6Addresses = addresses
|
||||||
|
networkSettings.IPv6NetworkPrefixes = networkPrefixes
|
||||||
|
logger.Info("Set IPv6 addresses: %v, network prefixes: %v", addresses, networkPrefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6IncludedRoutes sets the included IPv6 routes
|
||||||
|
func SetIPv6IncludedRoutes(routes []IPv6Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6IncludedRoutes = routes
|
||||||
|
logger.Info("Set IPv6 included routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIPv6ExcludedRoutes sets the excluded IPv6 routes
|
||||||
|
func SetIPv6ExcludedRoutes(routes []IPv6Route) {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings.IPv6ExcludedRoutes = routes
|
||||||
|
logger.Info("Set IPv6 excluded routes: %d routes", len(routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClearNetworkSettings clears all network settings
|
||||||
|
func ClearNetworkSettings() {
|
||||||
|
networkSettingsMutex.Lock()
|
||||||
|
defer networkSettingsMutex.Unlock()
|
||||||
|
networkSettings = NetworkSettings{}
|
||||||
|
logger.Info("Cleared all network settings")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetNetworkSettingsJSON() (string, error) {
|
||||||
|
networkSettingsMutex.RLock()
|
||||||
|
defer networkSettingsMutex.RUnlock()
|
||||||
|
data, err := json.MarshalIndent(networkSettings, "", " ")
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(data), nil
|
||||||
|
}
|
||||||
693
olm/common.go
693
olm/common.go
@@ -3,116 +3,13 @@ package olm
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"os/exec"
|
|
||||||
"regexp"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/util"
|
|
||||||
"github.com/fosrl/olm/peermonitor"
|
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
"golang.zx2c4.com/wireguard/device"
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type WgData struct {
|
|
||||||
Sites []SiteConfig `json:"sites"`
|
|
||||||
TunnelIP string `json:"tunnelIP"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type SiteConfig struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
ServerIP string `json:"serverIP"`
|
|
||||||
ServerPort uint16 `json:"serverPort"`
|
|
||||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetsByType struct {
|
|
||||||
UDP []string `json:"udp"`
|
|
||||||
TCP []string `json:"tcp"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type TargetData struct {
|
|
||||||
Targets []string `json:"targets"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HolePunchMessage struct {
|
|
||||||
NewtID string `json:"newtId"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type ExitNode struct {
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type HolePunchData struct {
|
|
||||||
ExitNodes []ExitNode `json:"exitNodes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type EncryptedHolePunchMessage struct {
|
|
||||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
|
||||||
Nonce []byte `json:"nonce"`
|
|
||||||
Ciphertext []byte `json:"ciphertext"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
peerMonitor *peermonitor.PeerMonitor
|
|
||||||
stopHolepunch chan struct{}
|
|
||||||
stopRegister func()
|
|
||||||
stopPing chan struct{}
|
|
||||||
olmToken string
|
|
||||||
holePunchRunning bool
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ENV_WG_TUN_FD = "WG_TUN_FD"
|
|
||||||
ENV_WG_UAPI_FD = "WG_UAPI_FD"
|
|
||||||
ENV_WG_PROCESS_FOREGROUND = "WG_PROCESS_FOREGROUND"
|
|
||||||
)
|
|
||||||
|
|
||||||
// PeerAction represents a request to add, update, or remove a peer
|
|
||||||
type PeerAction struct {
|
|
||||||
Action string `json:"action"` // "add", "update", or "remove"
|
|
||||||
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerData represents the data needed to update a peer
|
|
||||||
type UpdatePeerData struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
ServerIP string `json:"serverIP"`
|
|
||||||
ServerPort uint16 `json:"serverPort"`
|
|
||||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPeerData represents the data needed to add a peer
|
|
||||||
type AddPeerData struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
ServerIP string `json:"serverIP"`
|
|
||||||
ServerPort uint16 `json:"serverPort"`
|
|
||||||
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePeerData represents the data needed to remove a peer
|
|
||||||
type RemovePeerData struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type RelayPeerData struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Helper function to format endpoints correctly
|
// Helper function to format endpoints correctly
|
||||||
func formatEndpoint(endpoint string) string {
|
func formatEndpoint(endpoint string) string {
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
@@ -140,21 +37,6 @@ func formatEndpoint(endpoint string) string {
|
|||||||
return endpoint
|
return endpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|
||||||
switch level {
|
|
||||||
case logger.DEBUG:
|
|
||||||
return device.LogLevelVerbose
|
|
||||||
// case logger.INFO:
|
|
||||||
// return device.LogLevel
|
|
||||||
case logger.WARN:
|
|
||||||
return device.LogLevelError
|
|
||||||
case logger.ERROR, logger.FATAL:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
default:
|
|
||||||
return device.LogLevelSilent
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func sendPing(olm *websocket.Client) error {
|
func sendPing(olm *websocket.Client) error {
|
||||||
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
err := olm.SendMessage("olm/ping", map[string]interface{}{
|
||||||
"timestamp": time.Now().Unix(),
|
"timestamp": time.Now().Unix(),
|
||||||
@@ -192,578 +74,3 @@ func keepSendingPing(olm *websocket.Client) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
|
||||||
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
|
||||||
siteHost, err := util.ResolveDomain(siteConfig.Endpoint)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
|
||||||
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
|
||||||
if len(allowedIp) > 1 {
|
|
||||||
allowedIp[1] = "32"
|
|
||||||
} else {
|
|
||||||
allowedIp = append(allowedIp, "32")
|
|
||||||
}
|
|
||||||
allowedIpStr := strings.Join(allowedIp, "/")
|
|
||||||
|
|
||||||
// Collect all allowed IPs in a slice
|
|
||||||
var allowedIPs []string
|
|
||||||
allowedIPs = append(allowedIPs, allowedIpStr)
|
|
||||||
|
|
||||||
// If we have anything in remoteSubnets, add those as well
|
|
||||||
if siteConfig.RemoteSubnets != "" {
|
|
||||||
// Split remote subnets by comma and add each one
|
|
||||||
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
|
||||||
for _, subnet := range remoteSubnets {
|
|
||||||
subnet = strings.TrimSpace(subnet)
|
|
||||||
if subnet != "" {
|
|
||||||
allowedIPs = append(allowedIPs, subnet)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Construct WireGuard config for this peer
|
|
||||||
var configBuilder strings.Builder
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
|
|
||||||
|
|
||||||
// Add each allowed IP separately
|
|
||||||
for _, allowedIP := range allowedIPs {
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
|
||||||
}
|
|
||||||
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
|
||||||
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
|
||||||
|
|
||||||
config := configBuilder.String()
|
|
||||||
logger.Debug("Configuring peer with config: %s", config)
|
|
||||||
|
|
||||||
err = dev.IpcSet(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up peer monitoring
|
|
||||||
if peerMonitor != nil {
|
|
||||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
|
||||||
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
|
||||||
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
|
||||||
|
|
||||||
primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
wgConfig := &peermonitor.WireGuardConfig{
|
|
||||||
SiteID: siteConfig.SiteId,
|
|
||||||
PublicKey: util.FixKey(siteConfig.PublicKey),
|
|
||||||
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
|
||||||
Endpoint: siteConfig.Endpoint,
|
|
||||||
PrimaryRelay: primaryRelay,
|
|
||||||
}
|
|
||||||
|
|
||||||
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePeer removes a peer from the WireGuard device
|
|
||||||
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
|
||||||
// Construct WireGuard config to remove the peer
|
|
||||||
var configBuilder strings.Builder
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
|
||||||
configBuilder.WriteString("remove=true\n")
|
|
||||||
|
|
||||||
config := configBuilder.String()
|
|
||||||
logger.Debug("Removing peer with config: %s", config)
|
|
||||||
|
|
||||||
err := dev.IpcSet(config)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop monitoring this peer
|
|
||||||
if peerMonitor != nil {
|
|
||||||
peerMonitor.RemovePeer(siteId)
|
|
||||||
logger.Info("Stopped monitoring for site %d", siteId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConfigureInterface configures a network interface with an IP address and brings it up
|
|
||||||
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
|
||||||
var ipAddr string = wgData.TunnelIP
|
|
||||||
|
|
||||||
// Parse the IP address and network
|
|
||||||
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux":
|
|
||||||
return configureLinux(interfaceName, ip, ipNet)
|
|
||||||
case "darwin":
|
|
||||||
return configureDarwin(interfaceName, ip, ipNet)
|
|
||||||
case "windows":
|
|
||||||
return configureWindows(interfaceName, ip, ipNet)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
||||||
logger.Info("Configuring Windows interface: %s", interfaceName)
|
|
||||||
|
|
||||||
// Calculate mask string (e.g., 255.255.255.0)
|
|
||||||
maskBits, _ := ipNet.Mask.Size()
|
|
||||||
mask := net.CIDRMask(maskBits, 32)
|
|
||||||
maskIP := net.IP(mask)
|
|
||||||
|
|
||||||
// Set the IP address using netsh
|
|
||||||
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
|
||||||
fmt.Sprintf("name=%s", interfaceName),
|
|
||||||
"source=static",
|
|
||||||
fmt.Sprintf("addr=%s", ip.String()),
|
|
||||||
fmt.Sprintf("mask=%s", maskIP.String()))
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
|
||||||
// But we'll explicitly enable it to be sure
|
|
||||||
cmd = exec.Command("netsh", "interface", "set", "interface",
|
|
||||||
interfaceName,
|
|
||||||
"admin=enable")
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
out, err = cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// delay 2 seconds
|
|
||||||
time.Sleep(8 * time.Second)
|
|
||||||
|
|
||||||
// Wait for the interface to be up and have the correct IP
|
|
||||||
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitForInterfaceUp polls the network interface until it's up or times out
|
|
||||||
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
|
||||||
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
|
||||||
deadline := time.Now().Add(timeout)
|
|
||||||
pollInterval := 500 * time.Millisecond
|
|
||||||
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
// Check if interface exists and is up
|
|
||||||
iface, err := net.InterfaceByName(interfaceName)
|
|
||||||
if err == nil {
|
|
||||||
// Check if interface is up
|
|
||||||
if iface.Flags&net.FlagUp != 0 {
|
|
||||||
// Check if it has the expected IP
|
|
||||||
addrs, err := iface.Addrs()
|
|
||||||
if err == nil {
|
|
||||||
for _, addr := range addrs {
|
|
||||||
ipNet, ok := addr.(*net.IPNet)
|
|
||||||
if ok && ipNet.IP.Equal(expectedIP) {
|
|
||||||
logger.Info("Interface %s is up with correct IP", interfaceName)
|
|
||||||
return nil // Interface is up with correct IP
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait before next check
|
|
||||||
time.Sleep(pollInterval)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
|
||||||
}
|
|
||||||
|
|
||||||
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
|
||||||
|
|
||||||
// Parse destination to get the IP and subnet
|
|
||||||
ip, ipNet, err := net.ParseCIDR(destination)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid destination address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the subnet mask
|
|
||||||
maskBits, _ := ipNet.Mask.Size()
|
|
||||||
mask := net.CIDRMask(maskBits, 32)
|
|
||||||
maskIP := net.IP(mask)
|
|
||||||
|
|
||||||
if gateway != "" {
|
|
||||||
// Route with specific gateway
|
|
||||||
cmd = exec.Command("route", "add",
|
|
||||||
ip.String(),
|
|
||||||
"mask", maskIP.String(),
|
|
||||||
gateway,
|
|
||||||
"metric", "1")
|
|
||||||
} else if interfaceName != "" {
|
|
||||||
// First, get the interface index
|
|
||||||
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
|
||||||
output, err := indexCmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the output to find the interface index
|
|
||||||
lines := strings.Split(string(output), "\n")
|
|
||||||
var ifIndex string
|
|
||||||
for _, line := range lines {
|
|
||||||
if strings.Contains(line, interfaceName) {
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) > 0 {
|
|
||||||
ifIndex = fields[0]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if ifIndex == "" {
|
|
||||||
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Convert to integer to validate
|
|
||||||
idx, err := strconv.Atoi(ifIndex)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid interface index: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Route via interface using the index
|
|
||||||
cmd = exec.Command("route", "add",
|
|
||||||
ip.String(),
|
|
||||||
"mask", maskIP.String(),
|
|
||||||
"0.0.0.0",
|
|
||||||
"if", strconv.Itoa(idx))
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("either gateway or interface must be specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func WindowsRemoveRoute(destination string) error {
|
|
||||||
// Parse destination to get the IP
|
|
||||||
ip, ipNet, err := net.ParseCIDR(destination)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("invalid destination address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate the subnet mask
|
|
||||||
maskBits, _ := ipNet.Mask.Size()
|
|
||||||
mask := net.CIDRMask(maskBits, 32)
|
|
||||||
maskIP := net.IP(mask)
|
|
||||||
|
|
||||||
cmd := exec.Command("route", "delete",
|
|
||||||
ip.String(),
|
|
||||||
"mask", maskIP.String())
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func findUnusedUTUN() (string, error) {
|
|
||||||
ifaces, err := net.Interfaces()
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
|
||||||
}
|
|
||||||
used := make(map[int]bool)
|
|
||||||
re := regexp.MustCompile(`^utun(\d+)$`)
|
|
||||||
for _, iface := range ifaces {
|
|
||||||
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
|
||||||
if num, err := strconv.Atoi(matches[1]); err == nil {
|
|
||||||
used[num] = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Try utun0 up to utun255.
|
|
||||||
for i := 0; i < 256; i++ {
|
|
||||||
if !used[i] {
|
|
||||||
return fmt.Sprintf("utun%d", i), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("no unused utun interface found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
||||||
logger.Info("Configuring darwin interface: %s", interfaceName)
|
|
||||||
|
|
||||||
prefix, _ := ipNet.Mask.Size()
|
|
||||||
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
|
||||||
|
|
||||||
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the interface
|
|
||||||
cmd = exec.Command("ifconfig", interfaceName, "up")
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err = cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
|
||||||
// Get the interface
|
|
||||||
link, err := netlink.LinkByName(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create the IP address attributes
|
|
||||||
addr := &netlink.Addr{
|
|
||||||
IPNet: &net.IPNet{
|
|
||||||
IP: ip,
|
|
||||||
Mask: ipNet.Mask,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the IP address to the interface
|
|
||||||
if err := netlink.AddrAdd(link, addr); err != nil {
|
|
||||||
return fmt.Errorf("failed to add IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bring up the interface
|
|
||||||
if err := netlink.LinkSetUp(link); err != nil {
|
|
||||||
return fmt.Errorf("failed to bring up interface: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
|
||||||
|
|
||||||
if gateway != "" {
|
|
||||||
// Route with specific gateway
|
|
||||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
|
||||||
} else if interfaceName != "" {
|
|
||||||
// Route via interface
|
|
||||||
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("either gateway or interface must be specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DarwinRemoveRoute(destination string) error {
|
|
||||||
if runtime.GOOS != "darwin" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
|
||||||
|
|
||||||
if gateway != "" {
|
|
||||||
// Route with specific gateway
|
|
||||||
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
|
||||||
} else if interfaceName != "" {
|
|
||||||
// Route via interface
|
|
||||||
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("either gateway or interface must be specified")
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func LinuxRemoveRoute(destination string) error {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd := exec.Command("ip", "route", "del", destination)
|
|
||||||
logger.Info("Running command: %v", cmd)
|
|
||||||
|
|
||||||
out, err := cmd.CombinedOutput()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRouteForServerIP adds an OS-specific route for the server IP
|
|
||||||
func addRouteForServerIP(serverIP, interfaceName string) error {
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
return DarwinAddRoute(serverIP, "", interfaceName)
|
|
||||||
}
|
|
||||||
// else if runtime.GOOS == "windows" {
|
|
||||||
// return WindowsAddRoute(serverIP, "", interfaceName)
|
|
||||||
// } else if runtime.GOOS == "linux" {
|
|
||||||
// return LinuxAddRoute(serverIP, "", interfaceName)
|
|
||||||
// }
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRouteForServerIP removes an OS-specific route for the server IP
|
|
||||||
func removeRouteForServerIP(serverIP string) error {
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
return DarwinRemoveRoute(serverIP)
|
|
||||||
}
|
|
||||||
// else if runtime.GOOS == "windows" {
|
|
||||||
// return WindowsRemoveRoute(serverIP)
|
|
||||||
// } else if runtime.GOOS == "linux" {
|
|
||||||
// return LinuxRemoveRoute(serverIP)
|
|
||||||
// }
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
|
||||||
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
|
||||||
if remoteSubnets == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split remote subnets by comma and add routes for each one
|
|
||||||
subnets := strings.Split(remoteSubnets, ",")
|
|
||||||
for _, subnet := range subnets {
|
|
||||||
subnet = strings.TrimSpace(subnet)
|
|
||||||
if subnet == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add route based on operating system
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
|
||||||
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if runtime.GOOS == "windows" {
|
|
||||||
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
|
||||||
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if runtime.GOOS == "linux" {
|
|
||||||
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
|
||||||
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Added route for remote subnet: %s", subnet)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
|
||||||
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
|
||||||
if remoteSubnets == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split remote subnets by comma and remove routes for each one
|
|
||||||
subnets := strings.Split(remoteSubnets, ",")
|
|
||||||
for _, subnet := range subnets {
|
|
||||||
subnet = strings.TrimSpace(subnet)
|
|
||||||
if subnet == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove route based on operating system
|
|
||||||
if runtime.GOOS == "darwin" {
|
|
||||||
if err := DarwinRemoveRoute(subnet); err != nil {
|
|
||||||
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if runtime.GOOS == "windows" {
|
|
||||||
if err := WindowsRemoveRoute(subnet); err != nil {
|
|
||||||
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if runtime.GOOS == "linux" {
|
|
||||||
if err := LinuxRemoveRoute(subnet); err != nil {
|
|
||||||
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Removed route for remote subnet: %s", subnet)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|||||||
213
olm/interface.go
Normal file
213
olm/interface.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"regexp"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/network"
|
||||||
|
"github.com/vishvananda/netlink"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigureInterface configures a network interface with an IP address and brings it up
|
||||||
|
func ConfigureInterface(interfaceName string, wgData WgData) error {
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ipAddr string = wgData.TunnelIP
|
||||||
|
|
||||||
|
// Parse the IP address and network
|
||||||
|
ip, ipNet, err := net.ParseCIDR(ipAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
|
// network.SetTunnelRemoteAddress() // what does this do?
|
||||||
|
network.SetIPv4Settings([]string{destinationAddress}, []string{mask})
|
||||||
|
apiServer.SetTunnelIP(destinationAddress)
|
||||||
|
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "linux":
|
||||||
|
return configureLinux(interfaceName, ip, ipNet)
|
||||||
|
case "darwin":
|
||||||
|
return configureDarwin(interfaceName, ip, ipNet)
|
||||||
|
case "windows":
|
||||||
|
return configureWindows(interfaceName, ip, ipNet)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureWindows(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring Windows interface: %s", interfaceName)
|
||||||
|
|
||||||
|
// Calculate mask string (e.g., 255.255.255.0)
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
// Set the IP address using netsh
|
||||||
|
cmd := exec.Command("netsh", "interface", "ipv4", "set", "address",
|
||||||
|
fmt.Sprintf("name=%s", interfaceName),
|
||||||
|
"source=static",
|
||||||
|
fmt.Sprintf("addr=%s", ip.String()),
|
||||||
|
fmt.Sprintf("mask=%s", maskIP.String()))
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface if needed (in Windows, setting the IP usually brings it up)
|
||||||
|
// But we'll explicitly enable it to be sure
|
||||||
|
cmd = exec.Command("netsh", "interface", "set", "interface",
|
||||||
|
interfaceName,
|
||||||
|
"admin=enable")
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("netsh enable interface command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// delay 2 seconds
|
||||||
|
time.Sleep(8 * time.Second)
|
||||||
|
|
||||||
|
// Wait for the interface to be up and have the correct IP
|
||||||
|
err = waitForInterfaceUp(interfaceName, ip, 30*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("interface did not come up within timeout: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// waitForInterfaceUp polls the network interface until it's up or times out
|
||||||
|
func waitForInterfaceUp(interfaceName string, expectedIP net.IP, timeout time.Duration) error {
|
||||||
|
logger.Info("Waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
pollInterval := 500 * time.Millisecond
|
||||||
|
|
||||||
|
for time.Now().Before(deadline) {
|
||||||
|
// Check if interface exists and is up
|
||||||
|
iface, err := net.InterfaceByName(interfaceName)
|
||||||
|
if err == nil {
|
||||||
|
// Check if interface is up
|
||||||
|
if iface.Flags&net.FlagUp != 0 {
|
||||||
|
// Check if it has the expected IP
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err == nil {
|
||||||
|
for _, addr := range addrs {
|
||||||
|
ipNet, ok := addr.(*net.IPNet)
|
||||||
|
if ok && ipNet.IP.Equal(expectedIP) {
|
||||||
|
logger.Info("Interface %s is up with correct IP", interfaceName)
|
||||||
|
return nil // Interface is up with correct IP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logger.Info("Interface %s is up but doesn't have expected IP yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s exists but is not up yet", interfaceName)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
logger.Info("Interface %s not found yet: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait before next check
|
||||||
|
time.Sleep(pollInterval)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("timed out waiting for interface %s to be up with IP %s", interfaceName, expectedIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func findUnusedUTUN() (string, error) {
|
||||||
|
ifaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to list interfaces: %v", err)
|
||||||
|
}
|
||||||
|
used := make(map[int]bool)
|
||||||
|
re := regexp.MustCompile(`^utun(\d+)$`)
|
||||||
|
for _, iface := range ifaces {
|
||||||
|
if matches := re.FindStringSubmatch(iface.Name); len(matches) == 2 {
|
||||||
|
if num, err := strconv.Atoi(matches[1]); err == nil {
|
||||||
|
used[num] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Try utun0 up to utun255.
|
||||||
|
for i := 0; i < 256; i++ {
|
||||||
|
if !used[i] {
|
||||||
|
return fmt.Sprintf("utun%d", i), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("no unused utun interface found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureDarwin(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
logger.Info("Configuring darwin interface: %s", interfaceName)
|
||||||
|
|
||||||
|
prefix, _ := ipNet.Mask.Size()
|
||||||
|
ipStr := fmt.Sprintf("%s/%d", ip.String(), prefix)
|
||||||
|
|
||||||
|
cmd := exec.Command("ifconfig", interfaceName, "inet", ipStr, ip.String(), "alias")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
cmd = exec.Command("ifconfig", interfaceName, "up")
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err = cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ifconfig up command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func configureLinux(interfaceName string, ip net.IP, ipNet *net.IPNet) error {
|
||||||
|
// Get the interface
|
||||||
|
link, err := netlink.LinkByName(interfaceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface %s: %v", interfaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the IP address attributes
|
||||||
|
addr := &netlink.Addr{
|
||||||
|
IPNet: &net.IPNet{
|
||||||
|
IP: ip,
|
||||||
|
Mask: ipNet.Mask,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the IP address to the interface
|
||||||
|
if err := netlink.AddrAdd(link, addr); err != nil {
|
||||||
|
return fmt.Errorf("failed to add IP address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the interface
|
||||||
|
if err := netlink.LinkSetUp(link); err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
92
olm/olm.go
92
olm/olm.go
@@ -4,9 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/bind"
|
"github.com/fosrl/newt/bind"
|
||||||
@@ -15,6 +13,7 @@ import (
|
|||||||
"github.com/fosrl/newt/updates"
|
"github.com/fosrl/newt/updates"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
|
"github.com/fosrl/olm/network"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
@@ -57,7 +56,8 @@ type Config struct {
|
|||||||
OrgID string
|
OrgID string
|
||||||
// DoNotCreateNewClient bool
|
// DoNotCreateNewClient bool
|
||||||
|
|
||||||
FileDescriptorTun uint32
|
FileDescriptorTun uint32
|
||||||
|
FileDescriptorUAPI uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -82,6 +82,7 @@ func Run(ctx context.Context, config Config) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||||
|
network.SetMTU(config.MTU)
|
||||||
|
|
||||||
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
if err := updates.CheckForUpdate("fosrl", "olm", config.Version); err != nil {
|
||||||
logger.Debug("Failed to check for updates: %v", err)
|
logger.Debug("Failed to check for updates: %v", err)
|
||||||
@@ -371,14 +372,14 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
if config.FileDescriptorTun != 0 {
|
if config.FileDescriptorTun != 0 {
|
||||||
return createTUNFromFD(config.FileDescriptorTun, config.MTU)
|
return createTUNFromFD(config.FileDescriptorTun, config.MTU)
|
||||||
}
|
}
|
||||||
|
var ifName = interfaceName
|
||||||
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||||
interfaceName, err := findUnusedUTUN()
|
ifName, err = findUnusedUTUN()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, config.MTU)
|
|
||||||
}
|
}
|
||||||
return tun.CreateTUN(interfaceName, config.MTU)
|
return tun.CreateTUN(ifName, config.MTU)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -386,45 +387,47 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
if config.FileDescriptorTun == 0 {
|
||||||
interfaceName = realInterfaceName
|
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||||
}
|
interfaceName = realInterfaceName
|
||||||
|
|
||||||
fileUAPI, err := func() (*os.File, error) {
|
|
||||||
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
|
|
||||||
fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
|
||||||
}
|
}
|
||||||
return uapiOpen(interfaceName)
|
|
||||||
}()
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("UAPI listen error: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
// fileUAPI, err := func() (*os.File, error) {
|
||||||
|
// if config.FileDescriptorUAPI != 0 {
|
||||||
|
// fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32)
|
||||||
|
// if err != nil {
|
||||||
|
// return nil, fmt.Errorf("invalid UAPI file descriptor: %v", err)
|
||||||
|
// }
|
||||||
|
// return os.NewFile(uintptr(fd), ""), nil
|
||||||
|
// }
|
||||||
|
// return uapiOpen(interfaceName)
|
||||||
|
// }()
|
||||||
|
// if err != nil {
|
||||||
|
// logger.Error("UAPI listen error: %v", err)
|
||||||
|
// os.Exit(1)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
|
||||||
uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
dev = device.NewDevice(tdev, sharedBind, device.NewLogger(util.MapToWireGuardLogLevel(loggerLevel), "wireguard: "))
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to listen on uapi socket: %v", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
for {
|
// if err != nil {
|
||||||
conn, err := uapiListener.Accept()
|
// logger.Error("Failed to listen on uapi socket: %v", err)
|
||||||
if err != nil {
|
// os.Exit(1)
|
||||||
|
// }
|
||||||
|
|
||||||
return
|
// go func() {
|
||||||
}
|
// for {
|
||||||
go dev.IpcHandle(conn)
|
// conn, err := uapiListener.Accept()
|
||||||
}
|
// if err != nil {
|
||||||
}()
|
|
||||||
logger.Info("UAPI listener started")
|
// return
|
||||||
|
// }
|
||||||
|
// go dev.IpcHandle(conn)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
// logger.Info("UAPI listener started")
|
||||||
|
|
||||||
if err = dev.Up(); err != nil {
|
if err = dev.Up(); err != nil {
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
@@ -432,7 +435,6 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
if err = ConfigureInterface(interfaceName, wgData); err != nil {
|
||||||
logger.Error("Failed to configure interface: %v", err)
|
logger.Error("Failed to configure interface: %v", err)
|
||||||
}
|
}
|
||||||
apiServer.SetTunnelIP(wgData.TunnelIP)
|
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
@@ -476,10 +478,10 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
logger.Error("Failed to add route for peer: %v", err)
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
// if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
logger.Error("Failed to add routes for remote subnets: %v", err)
|
// logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
return
|
// return
|
||||||
}
|
// }
|
||||||
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
}
|
}
|
||||||
@@ -671,7 +673,7 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove route for the peer
|
// Remove route for the peer
|
||||||
err = removeRouteForServerIP(peerToRemove.ServerIP)
|
err = removeRouteForServerIP(peerToRemove.ServerIP, interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove route for peer: %v", err)
|
logger.Error("Failed to remove route for peer: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
121
olm/peer.go
Normal file
121
olm/peer.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
|
"github.com/fosrl/olm/peermonitor"
|
||||||
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigurePeer sets up or updates a peer within the WireGuard device
|
||||||
|
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error {
|
||||||
|
siteHost, err := util.ResolveDomain(siteConfig.Endpoint)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split off the CIDR of the server IP which is just a string and add /32 for the allowed IP
|
||||||
|
allowedIp := strings.Split(siteConfig.ServerIP, "/")
|
||||||
|
if len(allowedIp) > 1 {
|
||||||
|
allowedIp[1] = "32"
|
||||||
|
} else {
|
||||||
|
allowedIp = append(allowedIp, "32")
|
||||||
|
}
|
||||||
|
allowedIpStr := strings.Join(allowedIp, "/")
|
||||||
|
|
||||||
|
// Collect all allowed IPs in a slice
|
||||||
|
var allowedIPs []string
|
||||||
|
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||||
|
|
||||||
|
// If we have anything in remoteSubnets, add those as well
|
||||||
|
if siteConfig.RemoteSubnets != "" {
|
||||||
|
// Split remote subnets by comma and add each one
|
||||||
|
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet != "" {
|
||||||
|
allowedIPs = append(allowedIPs, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Construct WireGuard config for this peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", util.FixKey(privateKey.String())))
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(siteConfig.PublicKey)))
|
||||||
|
|
||||||
|
// Add each allowed IP separately
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||||
|
}
|
||||||
|
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||||
|
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Configuring peer with config: %s", config)
|
||||||
|
|
||||||
|
err = dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to configure WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up peer monitoring
|
||||||
|
if peerMonitor != nil {
|
||||||
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||||
|
monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port
|
||||||
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
|
||||||
|
primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
wgConfig := &peermonitor.WireGuardConfig{
|
||||||
|
SiteID: siteConfig.SiteId,
|
||||||
|
PublicKey: util.FixKey(siteConfig.PublicKey),
|
||||||
|
ServerIP: strings.Split(siteConfig.ServerIP, "/")[0],
|
||||||
|
Endpoint: siteConfig.Endpoint,
|
||||||
|
PrimaryRelay: primaryRelay,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig)
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeer removes a peer from the WireGuard device
|
||||||
|
func RemovePeer(dev *device.Device, siteId int, publicKey string) error {
|
||||||
|
// Construct WireGuard config to remove the peer
|
||||||
|
var configBuilder strings.Builder
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
|
||||||
|
configBuilder.WriteString("remove=true\n")
|
||||||
|
|
||||||
|
config := configBuilder.String()
|
||||||
|
logger.Debug("Removing peer with config: %s", config)
|
||||||
|
|
||||||
|
err := dev.IpcSet(config)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to remove WireGuard peer: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop monitoring this peer
|
||||||
|
if peerMonitor != nil {
|
||||||
|
peerMonitor.RemovePeer(siteId)
|
||||||
|
logger.Info("Stopped monitoring for site %d", siteId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
358
olm/route.go
Normal file
358
olm/route.go
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/olm/network"
|
||||||
|
)
|
||||||
|
|
||||||
|
func DarwinAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-gateway", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("route", "-q", "-n", "add", "-inet", destination, "-interface", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DarwinRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "darwin" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "-q", "-n", "delete", "-inet", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("ip", "route", "del", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
// Parse destination to get the IP and subnet
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
gateway,
|
||||||
|
"metric", "1")
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// First, get the interface index
|
||||||
|
indexCmd := exec.Command("netsh", "interface", "ipv4", "show", "interfaces")
|
||||||
|
output, err := indexCmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get interface index: %v, output: %s", err, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the output to find the interface index
|
||||||
|
lines := strings.Split(string(output), "\n")
|
||||||
|
var ifIndex string
|
||||||
|
for _, line := range lines {
|
||||||
|
if strings.Contains(line, interfaceName) {
|
||||||
|
fields := strings.Fields(line)
|
||||||
|
if len(fields) > 0 {
|
||||||
|
ifIndex = fields[0]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if ifIndex == "" {
|
||||||
|
return fmt.Errorf("could not find index for interface %s", interfaceName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to integer to validate
|
||||||
|
idx, err := strconv.Atoi(ifIndex)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid interface index: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Route via interface using the index
|
||||||
|
cmd = exec.Command("route", "add",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String(),
|
||||||
|
"0.0.0.0",
|
||||||
|
"if", strconv.Itoa(idx))
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func WindowsRemoveRoute(destination string) error {
|
||||||
|
// Parse destination to get the IP
|
||||||
|
ip, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the subnet mask
|
||||||
|
maskBits, _ := ipNet.Mask.Size()
|
||||||
|
mask := net.CIDRMask(maskBits, 32)
|
||||||
|
maskIP := net.IP(mask)
|
||||||
|
|
||||||
|
cmd := exec.Command("route", "delete",
|
||||||
|
ip.String(),
|
||||||
|
"mask", maskIP.String())
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
|
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
|
if err := addRouteForNetworkConfig(serverIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
|
func removeRouteForServerIP(serverIP string, interfaceName string) error {
|
||||||
|
if err := removeRouteForNetworkConfig(serverIP); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if interfaceName == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinRemoveRoute(serverIP)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsRemoveRoute(serverIP)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxRemoveRoute(serverIP)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func addRouteForNetworkConfig(destination string) error {
|
||||||
|
// Parse the subnet to extract IP and mask
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
|
network.AddIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func removeRouteForNetworkConfig(destination string) error {
|
||||||
|
// Parse the subnet to extract IP and mask
|
||||||
|
_, ipNet, err := net.ParseCIDR(destination)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse subnet %s: %v", destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert CIDR mask to dotted decimal format (e.g., 255.255.255.0)
|
||||||
|
mask := net.IP(ipNet.Mask).String()
|
||||||
|
destinationAddress := ipNet.IP.String()
|
||||||
|
|
||||||
|
network.RemoveIPv4IncludedRoute(network.IPv4Route{DestinationAddress: destinationAddress, SubnetMask: mask})
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and add routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := addRouteForNetworkConfig(subnet); err != nil {
|
||||||
|
logger.Error("Failed to add network config for subnet %s: %v", subnet, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add route based on operating system
|
||||||
|
if interfaceName == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Added route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and remove routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := removeRouteForNetworkConfig(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove network config for subnet %s: %v", subnet, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
91
olm/types.go
Normal file
91
olm/types.go
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
package olm
|
||||||
|
|
||||||
|
import "github.com/fosrl/olm/peermonitor"
|
||||||
|
|
||||||
|
type WgData struct {
|
||||||
|
Sites []SiteConfig `json:"sites"`
|
||||||
|
TunnelIP string `json:"tunnelIP"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SiteConfig struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetsByType struct {
|
||||||
|
UDP []string `json:"udp"`
|
||||||
|
TCP []string `json:"tcp"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TargetData struct {
|
||||||
|
Targets []string `json:"targets"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchMessage struct {
|
||||||
|
NewtID string `json:"newtId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type HolePunchData struct {
|
||||||
|
ExitNodes []ExitNode `json:"exitNodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type EncryptedHolePunchMessage struct {
|
||||||
|
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||||
|
Nonce []byte `json:"nonce"`
|
||||||
|
Ciphertext []byte `json:"ciphertext"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
peerMonitor *peermonitor.PeerMonitor
|
||||||
|
stopHolepunch chan struct{}
|
||||||
|
stopRegister func()
|
||||||
|
stopPing chan struct{}
|
||||||
|
olmToken string
|
||||||
|
holePunchRunning bool
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerAction represents a request to add, update, or remove a peer
|
||||||
|
type PeerAction struct {
|
||||||
|
Action string `json:"action"` // "add", "update", or "remove"
|
||||||
|
SiteInfo SiteConfig `json:"siteInfo"` // Site configuration information
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePeerData represents the data needed to update a peer
|
||||||
|
type UpdatePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddPeerData represents the data needed to add a peer
|
||||||
|
type AddPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
ServerIP string `json:"serverIP"`
|
||||||
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemovePeerData represents the data needed to remove a peer
|
||||||
|
type RemovePeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type RelayPeerData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
}
|
||||||
12
olm/unix.go
12
olm/unix.go
@@ -6,20 +6,26 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
func createTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||||
err := unix.SetNonblock(int(tunFd), true)
|
dupTunFd, err := unix.Dup(int(tunFd))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Unable to dup tun fd: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = unix.SetNonblock(dupTunFd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
file := os.NewFile(uintptr(tunFd), "")
|
return tun.CreateTUNFromFile(os.NewFile(uintptr(dupTunFd), "/dev/tun"), mtuInt)
|
||||||
return tun.CreateTUNFromFile(file, mtuInt)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func uapiOpen(interfaceName string) (*os.File, error) {
|
func uapiOpen(interfaceName string) (*os.File, error) {
|
||||||
return ipc.UAPIOpen(interfaceName)
|
return ipc.UAPIOpen(interfaceName)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user