From 516eae6d96f841ac9fd08aadfccc388ac3226cda Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 27 Jul 2025 14:49:57 -0700 Subject: [PATCH] Handle remote routing --- common.go | 200 +++++++++++++++++++++++++++++++++++++++++++++++++----- main.go | 87 ++++++++++++++++-------- 2 files changed, 242 insertions(+), 45 deletions(-) diff --git a/common.go b/common.go index b3d1ce4..db8c155 100644 --- a/common.go +++ b/common.go @@ -31,11 +31,12 @@ type WgData struct { } type SiteConfig struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access } type TargetsByType struct { @@ -91,20 +92,22 @@ type PeerAction struct { // UpdatePeerData represents the data needed to update a peer type UpdatePeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access } // AddPeerData represents the data needed to add a peer type AddPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access } // RemovePeerData represents the data needed to remove a peer @@ -467,11 +470,32 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } allowedIpStr := strings.Join(allowedIp, "/") + // Collect all allowed IPs in a slice + var allowedIPs []string + allowedIPs = append(allowedIPs, allowedIpStr) + + // If we have anything in remoteSubnets, add those as well + if siteConfig.RemoteSubnets != "" { + // Split remote subnets by comma and add each one + remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") + for _, subnet := range remoteSubnets { + subnet = strings.TrimSpace(subnet) + if subnet != "" { + allowedIPs = append(allowedIPs, subnet) + } + } + } + // Construct WireGuard config for this peer var configBuilder strings.Builder configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey))) - configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr)) + + // Add each allowed IP separately + for _, allowedIP := range allowedIPs { + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP)) + } + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) configBuilder.WriteString("persistent_keepalive_interval=1\n") @@ -487,7 +511,6 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes if peerMonitor != nil { monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, siteConfig.ServerPort+1) // +1 for the monitor port - logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable @@ -862,3 +885,146 @@ func DarwinRemoveRoute(destination string) error { return nil } + +func LinuxAddRoute(destination string, gateway string, interfaceName string) error { + if runtime.GOOS != "linux" { + return nil + } + + var cmd *exec.Cmd + + if gateway != "" { + // Route with specific gateway + cmd = exec.Command("ip", "route", "add", destination, "via", gateway) + } else if interfaceName != "" { + // Route via interface + cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName) + } else { + return fmt.Errorf("either gateway or interface must be specified") + } + + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route command failed: %v, output: %s", err, out) + } + + return nil +} + +func LinuxRemoveRoute(destination string) error { + if runtime.GOOS != "linux" { + return nil + } + + cmd := exec.Command("ip", "route", "del", destination) + logger.Info("Running command: %v", cmd) + + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out) + } + + return nil +} + +// addRouteForServerIP adds an OS-specific route for the server IP +func addRouteForServerIP(serverIP, interfaceName string) error { + if runtime.GOOS == "darwin" { + return DarwinAddRoute(serverIP, "", interfaceName) + } + // else if runtime.GOOS == "windows" { + // return WindowsAddRoute(serverIP, "", interfaceName) + // } else if runtime.GOOS == "linux" { + // return LinuxAddRoute(serverIP, "", interfaceName) + // } + return nil +} + +// removeRouteForServerIP removes an OS-specific route for the server IP +func removeRouteForServerIP(serverIP string) error { + if runtime.GOOS == "darwin" { + return DarwinRemoveRoute(serverIP) + } + // else if runtime.GOOS == "windows" { + // return WindowsRemoveRoute(serverIP) + // } else if runtime.GOOS == "linux" { + // return LinuxRemoveRoute(serverIP) + // } + return nil +} + +// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets +func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and add routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + // Add route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxAddRoute(subnet, "", interfaceName); err != nil { + logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Added route for remote subnet: %s", subnet) + } + return nil +} + +// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets +func removeRoutesForRemoteSubnets(remoteSubnets string) error { + if remoteSubnets == "" { + return nil + } + + // Split remote subnets by comma and remove routes for each one + subnets := strings.Split(remoteSubnets, ",") + for _, subnet := range subnets { + subnet = strings.TrimSpace(subnet) + if subnet == "" { + continue + } + + // Remove route based on operating system + if runtime.GOOS == "darwin" { + if err := DarwinRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "windows" { + if err := WindowsRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err) + return err + } + } else if runtime.GOOS == "linux" { + if err := LinuxRemoveRoute(subnet); err != nil { + logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err) + return err + } + } + + logger.Info("Removed route for remote subnet: %s", subnet) + } + return nil +} diff --git a/main.go b/main.go index 265433e..ebb76af 100644 --- a/main.go +++ b/main.go @@ -450,6 +450,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { stopRegister = nil } + close(stopHolepunch) + + // wait 10 milliseconds to ensure the previous connection is closed + time.Sleep(10 * time.Millisecond) + // if there is an existing tunnel then close it if dev != nil { logger.Info("Got new message. Closing existing tunnel!") @@ -544,8 +549,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { logger.Info("UAPI listener started") - close(stopHolepunch) - // Bring up the device err = dev.Up() if err != nil { @@ -586,16 +589,17 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { return } - err = DarwinAddRoute(site.ServerIP, "", interfaceName) + err = addRouteForServerIP(site.ServerIP, interfaceName) if err != nil { logger.Error("Failed to add route for peer: %v", err) return } - // err = WindowsAddRoute(site.ServerIP, "", interfaceName) - // if err != nil { - // logger.Error("Failed to add route for peer: %v", err) - // return - // } + + // Add routes for remote subnets + if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } logger.Info("Configured peer %s", site.PublicKey) } @@ -622,21 +626,45 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Convert to SiteConfig siteConfig := SiteConfig{ - SiteId: updateData.SiteId, - Endpoint: updateData.Endpoint, - PublicKey: updateData.PublicKey, - ServerIP: updateData.ServerIP, - ServerPort: updateData.ServerPort, + SiteId: updateData.SiteId, + Endpoint: updateData.Endpoint, + PublicKey: updateData.PublicKey, + ServerIP: updateData.ServerIP, + ServerPort: updateData.ServerPort, + RemoteSubnets: updateData.RemoteSubnets, } // Update the peer in WireGuard if dev != nil { + // Find the existing peer to get old RemoteSubnets + var oldRemoteSubnets string + for _, site := range wgData.Sites { + if site.SiteId == updateData.SiteId { + oldRemoteSubnets = site.RemoteSubnets + break + } + } + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err) // Send error response if needed return } + // Remove old remote subnet routes if they changed + if oldRemoteSubnets != siteConfig.RemoteSubnets { + if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { + logger.Error("Failed to remove old remote subnet routes: %v", err) + // Continue anyway to add new routes + } + + // Add new remote subnet routes + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add new remote subnet routes: %v", err) + return + } + } + // Update successful logger.Info("Successfully updated peer for site %d", updateData.SiteId) // If this is part of a WgData structure, update it @@ -669,11 +697,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Convert to SiteConfig siteConfig := SiteConfig{ - SiteId: addData.SiteId, - Endpoint: addData.Endpoint, - PublicKey: addData.PublicKey, - ServerIP: addData.ServerIP, - ServerPort: addData.ServerPort, + SiteId: addData.SiteId, + Endpoint: addData.Endpoint, + PublicKey: addData.PublicKey, + ServerIP: addData.ServerIP, + ServerPort: addData.ServerPort, + RemoteSubnets: addData.RemoteSubnets, } // Add the peer to WireGuard @@ -684,16 +713,17 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { } // Add route for the new peer - err = DarwinAddRoute(siteConfig.ServerIP, "", interfaceName) + err = addRouteForServerIP(siteConfig.ServerIP, interfaceName) if err != nil { logger.Error("Failed to add route for new peer: %v", err) return } - // err = WindowsAddRoute(siteConfig.ServerIP, "", interfaceName) - // if err != nil { - // logger.Error("Failed to add route for new peer: %v", err) - // return - // } + + // Add routes for remote subnets + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } // Add successful logger.Info("Successfully added peer for site %d", addData.SiteId) @@ -747,14 +777,15 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { } // Remove route for the peer - err = DarwinRemoveRoute(peerToRemove.ServerIP) + err = removeRouteForServerIP(peerToRemove.ServerIP) if err != nil { logger.Error("Failed to remove route for peer: %v", err) return } - err = WindowsRemoveRoute(peerToRemove.ServerIP) - if err != nil { - logger.Error("Failed to remove route for peer: %v", err) + + // Remove routes for remote subnets + if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil { + logger.Error("Failed to remove routes for remote subnets: %v", err) return }