mirror of
https://github.com/fosrl/olm.git
synced 2026-02-28 15:56:43 +00:00
Handle remote routing
This commit is contained in:
200
common.go
200
common.go
@@ -31,11 +31,12 @@ type WgData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SiteConfig struct {
|
type SiteConfig struct {
|
||||||
SiteId int `json:"siteId"`
|
SiteId int `json:"siteId"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
ServerIP string `json:"serverIP"`
|
ServerIP string `json:"serverIP"`
|
||||||
ServerPort uint16 `json:"serverPort"`
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
}
|
}
|
||||||
|
|
||||||
type TargetsByType struct {
|
type TargetsByType struct {
|
||||||
@@ -91,20 +92,22 @@ type PeerAction struct {
|
|||||||
|
|
||||||
// UpdatePeerData represents the data needed to update a peer
|
// UpdatePeerData represents the data needed to update a peer
|
||||||
type UpdatePeerData struct {
|
type UpdatePeerData struct {
|
||||||
SiteId int `json:"siteId"`
|
SiteId int `json:"siteId"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
ServerIP string `json:"serverIP"`
|
ServerIP string `json:"serverIP"`
|
||||||
ServerPort uint16 `json:"serverPort"`
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddPeerData represents the data needed to add a peer
|
// AddPeerData represents the data needed to add a peer
|
||||||
type AddPeerData struct {
|
type AddPeerData struct {
|
||||||
SiteId int `json:"siteId"`
|
SiteId int `json:"siteId"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
ServerIP string `json:"serverIP"`
|
ServerIP string `json:"serverIP"`
|
||||||
ServerPort uint16 `json:"serverPort"`
|
ServerPort uint16 `json:"serverPort"`
|
||||||
|
RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeerData represents the data needed to remove a peer
|
// RemovePeerData represents the data needed to remove a peer
|
||||||
@@ -467,11 +470,32 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
|
|||||||
}
|
}
|
||||||
allowedIpStr := strings.Join(allowedIp, "/")
|
allowedIpStr := strings.Join(allowedIp, "/")
|
||||||
|
|
||||||
|
// Collect all allowed IPs in a slice
|
||||||
|
var allowedIPs []string
|
||||||
|
allowedIPs = append(allowedIPs, allowedIpStr)
|
||||||
|
|
||||||
|
// If we have anything in remoteSubnets, add those as well
|
||||||
|
if siteConfig.RemoteSubnets != "" {
|
||||||
|
// Split remote subnets by comma and add each one
|
||||||
|
remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",")
|
||||||
|
for _, subnet := range remoteSubnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet != "" {
|
||||||
|
allowedIPs = append(allowedIPs, subnet)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Construct WireGuard config for this peer
|
// Construct WireGuard config for this peer
|
||||||
var configBuilder strings.Builder
|
var configBuilder strings.Builder
|
||||||
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String())))
|
||||||
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(siteConfig.PublicKey)))
|
||||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr))
|
|
||||||
|
// Add each allowed IP separately
|
||||||
|
for _, allowedIP := range allowedIPs {
|
||||||
|
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIP))
|
||||||
|
}
|
||||||
|
|
||||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||||
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||||
|
|
||||||
@@ -487,7 +511,6 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
|
|||||||
if peerMonitor != nil {
|
if peerMonitor != nil {
|
||||||
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0]
|
||||||
monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, siteConfig.ServerPort+1) // +1 for the monitor port
|
monitorPeer := fmt.Sprintf("%s:%d", monitorAddress, siteConfig.ServerPort+1) // +1 for the monitor port
|
||||||
|
|
||||||
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer)
|
||||||
|
|
||||||
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
|
primaryRelay, err := resolveDomain(endpoint) // Using global endpoint variable
|
||||||
@@ -862,3 +885,146 @@ func DarwinRemoveRoute(destination string) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func LinuxAddRoute(destination string, gateway string, interfaceName string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
if gateway != "" {
|
||||||
|
// Route with specific gateway
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "via", gateway)
|
||||||
|
} else if interfaceName != "" {
|
||||||
|
// Route via interface
|
||||||
|
cmd = exec.Command("ip", "route", "add", destination, "dev", interfaceName)
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("either gateway or interface must be specified")
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LinuxRemoveRoute(destination string) error {
|
||||||
|
if runtime.GOOS != "linux" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := exec.Command("ip", "route", "del", destination)
|
||||||
|
logger.Info("Running command: %v", cmd)
|
||||||
|
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("ip route delete command failed: %v, output: %s", err, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRouteForServerIP adds an OS-specific route for the server IP
|
||||||
|
func addRouteForServerIP(serverIP, interfaceName string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinAddRoute(serverIP, "", interfaceName)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsAddRoute(serverIP, "", interfaceName)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxAddRoute(serverIP, "", interfaceName)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRouteForServerIP removes an OS-specific route for the server IP
|
||||||
|
func removeRouteForServerIP(serverIP string) error {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
return DarwinRemoveRoute(serverIP)
|
||||||
|
}
|
||||||
|
// else if runtime.GOOS == "windows" {
|
||||||
|
// return WindowsRemoveRoute(serverIP)
|
||||||
|
// } else if runtime.GOOS == "linux" {
|
||||||
|
// return LinuxRemoveRoute(serverIP)
|
||||||
|
// }
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and add routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxAddRoute(subnet, "", interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Added route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets
|
||||||
|
func removeRoutesForRemoteSubnets(remoteSubnets string) error {
|
||||||
|
if remoteSubnets == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split remote subnets by comma and remove routes for each one
|
||||||
|
subnets := strings.Split(remoteSubnets, ",")
|
||||||
|
for _, subnet := range subnets {
|
||||||
|
subnet = strings.TrimSpace(subnet)
|
||||||
|
if subnet == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove route based on operating system
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if err := DarwinRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Darwin route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "windows" {
|
||||||
|
if err := WindowsRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Windows route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else if runtime.GOOS == "linux" {
|
||||||
|
if err := LinuxRemoveRoute(subnet); err != nil {
|
||||||
|
logger.Error("Failed to remove Linux route for subnet %s: %v", subnet, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Removed route for remote subnet: %s", subnet)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
87
main.go
87
main.go
@@ -450,6 +450,11 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
stopRegister = nil
|
stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
close(stopHolepunch)
|
||||||
|
|
||||||
|
// wait 10 milliseconds to ensure the previous connection is closed
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
// if there is an existing tunnel then close it
|
// if there is an existing tunnel then close it
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
logger.Info("Got new message. Closing existing tunnel!")
|
logger.Info("Got new message. Closing existing tunnel!")
|
||||||
@@ -544,8 +549,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
logger.Info("UAPI listener started")
|
logger.Info("UAPI listener started")
|
||||||
|
|
||||||
close(stopHolepunch)
|
|
||||||
|
|
||||||
// Bring up the device
|
// Bring up the device
|
||||||
err = dev.Up()
|
err = dev.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -586,16 +589,17 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = DarwinAddRoute(site.ServerIP, "", interfaceName)
|
err = addRouteForServerIP(site.ServerIP, interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to add route for peer: %v", err)
|
logger.Error("Failed to add route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// err = WindowsAddRoute(site.ServerIP, "", interfaceName)
|
|
||||||
// if err != nil {
|
// Add routes for remote subnets
|
||||||
// logger.Error("Failed to add route for peer: %v", err)
|
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
|
||||||
// return
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
// }
|
return
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
}
|
}
|
||||||
@@ -622,21 +626,45 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
// Convert to SiteConfig
|
// Convert to SiteConfig
|
||||||
siteConfig := SiteConfig{
|
siteConfig := SiteConfig{
|
||||||
SiteId: updateData.SiteId,
|
SiteId: updateData.SiteId,
|
||||||
Endpoint: updateData.Endpoint,
|
Endpoint: updateData.Endpoint,
|
||||||
PublicKey: updateData.PublicKey,
|
PublicKey: updateData.PublicKey,
|
||||||
ServerIP: updateData.ServerIP,
|
ServerIP: updateData.ServerIP,
|
||||||
ServerPort: updateData.ServerPort,
|
ServerPort: updateData.ServerPort,
|
||||||
|
RemoteSubnets: updateData.RemoteSubnets,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update the peer in WireGuard
|
// Update the peer in WireGuard
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
|
// Find the existing peer to get old RemoteSubnets
|
||||||
|
var oldRemoteSubnets string
|
||||||
|
for _, site := range wgData.Sites {
|
||||||
|
if site.SiteId == updateData.SiteId {
|
||||||
|
oldRemoteSubnets = site.RemoteSubnets
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
|
||||||
logger.Error("Failed to update peer: %v", err)
|
logger.Error("Failed to update peer: %v", err)
|
||||||
// Send error response if needed
|
// Send error response if needed
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remove old remote subnet routes if they changed
|
||||||
|
if oldRemoteSubnets != siteConfig.RemoteSubnets {
|
||||||
|
if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove old remote subnet routes: %v", err)
|
||||||
|
// Continue anyway to add new routes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add new remote subnet routes
|
||||||
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
|
logger.Error("Failed to add new remote subnet routes: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Update successful
|
// Update successful
|
||||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||||
// If this is part of a WgData structure, update it
|
// If this is part of a WgData structure, update it
|
||||||
@@ -669,11 +697,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
|
|
||||||
// Convert to SiteConfig
|
// Convert to SiteConfig
|
||||||
siteConfig := SiteConfig{
|
siteConfig := SiteConfig{
|
||||||
SiteId: addData.SiteId,
|
SiteId: addData.SiteId,
|
||||||
Endpoint: addData.Endpoint,
|
Endpoint: addData.Endpoint,
|
||||||
PublicKey: addData.PublicKey,
|
PublicKey: addData.PublicKey,
|
||||||
ServerIP: addData.ServerIP,
|
ServerIP: addData.ServerIP,
|
||||||
ServerPort: addData.ServerPort,
|
ServerPort: addData.ServerPort,
|
||||||
|
RemoteSubnets: addData.RemoteSubnets,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the peer to WireGuard
|
// Add the peer to WireGuard
|
||||||
@@ -684,16 +713,17 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Add route for the new peer
|
// Add route for the new peer
|
||||||
err = DarwinAddRoute(siteConfig.ServerIP, "", interfaceName)
|
err = addRouteForServerIP(siteConfig.ServerIP, interfaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to add route for new peer: %v", err)
|
logger.Error("Failed to add route for new peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// err = WindowsAddRoute(siteConfig.ServerIP, "", interfaceName)
|
|
||||||
// if err != nil {
|
// Add routes for remote subnets
|
||||||
// logger.Error("Failed to add route for new peer: %v", err)
|
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
|
||||||
// return
|
logger.Error("Failed to add routes for remote subnets: %v", err)
|
||||||
// }
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Add successful
|
// Add successful
|
||||||
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
logger.Info("Successfully added peer for site %d", addData.SiteId)
|
||||||
@@ -747,14 +777,15 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Remove route for the peer
|
// Remove route for the peer
|
||||||
err = DarwinRemoveRoute(peerToRemove.ServerIP)
|
err = removeRouteForServerIP(peerToRemove.ServerIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to remove route for peer: %v", err)
|
logger.Error("Failed to remove route for peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
err = WindowsRemoveRoute(peerToRemove.ServerIP)
|
|
||||||
if err != nil {
|
// Remove routes for remote subnets
|
||||||
logger.Error("Failed to remove route for peer: %v", err)
|
if err := removeRoutesForRemoteSubnets(peerToRemove.RemoteSubnets); err != nil {
|
||||||
|
logger.Error("Failed to remove routes for remote subnets: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user