From 8cff1d37fa9135eaefc02bc9eca73b0a4953e590 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 20 Nov 2025 14:21:27 -0500 Subject: [PATCH] Updates to support updates --- olm/common.go | 13 ++++ olm/olm.go | 173 +++++++++++++++++++++++++++----------------------- olm/peer.go | 7 +- olm/route.go | 22 +++---- olm/types.go | 36 +++++------ 5 files changed, 138 insertions(+), 113 deletions(-) diff --git a/olm/common.go b/olm/common.go index 0dc8420..1f7348f 100644 --- a/olm/common.go +++ b/olm/common.go @@ -83,3 +83,16 @@ func GetNetworkSettingsJSON() (string, error) { func GetNetworkSettingsIncrementor() int { return network.GetIncrementor() } + +// stringSlicesEqual compares two string slices for equality +func stringSlicesEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/olm/olm.go b/olm/olm.go index d403ed0..386cf30 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -543,71 +543,86 @@ func StartTunnel(config TunnelConfig) { return } - // Convert to SiteConfig - siteConfig := SiteConfig{ - SiteId: updateData.SiteId, - Endpoint: updateData.Endpoint, - PublicKey: updateData.PublicKey, - ServerIP: updateData.ServerIP, - ServerPort: updateData.ServerPort, - RemoteSubnets: updateData.RemoteSubnets, + // Update the peer in WireGuard + if dev == nil { + logger.Error("WireGuard device not initialized") + return } - // Update the peer in WireGuard - if dev != nil { - // Find the existing peer to get old data - var oldRemoteSubnets string - var oldPublicKey string - for _, site := range wgData.Sites { - if site.SiteId == updateData.SiteId { - oldRemoteSubnets = site.RemoteSubnets - oldPublicKey = site.PublicKey - break - } + // Find the existing peer to merge updates with + var existingPeer *SiteConfig + var peerIndex int + for i, site := range wgData.Sites { + if site.SiteId == updateData.SiteId { + existingPeer = &wgData.Sites[i] + peerIndex = i + break } + } - // If the public key has changed, remove the old peer first - if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { - logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) - if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { - logger.Error("Failed to remove old peer: %v", err) - return - } - } + if existingPeer == nil { + logger.Error("Peer with site ID %d not found", updateData.SiteId) + return + } - // Format the endpoint before updating the peer. - siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) + // Store old values for comparison + oldRemoteSubnets := existingPeer.RemoteSubnets + oldPublicKey := existingPeer.PublicKey - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to update peer: %v", err) + // Create updated site config by merging with existing data + // Only update fields that are provided (non-empty/non-zero) + siteConfig := *existingPeer // Start with existing data + + if updateData.Endpoint != "" { + siteConfig.Endpoint = updateData.Endpoint + } + if updateData.PublicKey != "" { + siteConfig.PublicKey = updateData.PublicKey + } + if updateData.ServerIP != "" { + siteConfig.ServerIP = updateData.ServerIP + } + if updateData.ServerPort != 0 { + siteConfig.ServerPort = updateData.ServerPort + } + if updateData.RemoteSubnets != nil { + siteConfig.RemoteSubnets = updateData.RemoteSubnets + } + + // If the public key has changed, remove the old peer first + if siteConfig.PublicKey != oldPublicKey { + logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) + if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { + logger.Error("Failed to remove old peer: %v", err) return } - - // Remove old remote subnet routes if they changed - if oldRemoteSubnets != siteConfig.RemoteSubnets { - if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { - logger.Error("Failed to remove old remote subnet routes: %v", err) - // Continue anyway to add new routes - } - - // Add new remote subnet routes - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add new remote subnet routes: %v", err) - return - } - } - - // Update successful - logger.Info("Successfully updated peer for site %d", updateData.SiteId) - for i := range wgData.Sites { - if wgData.Sites[i].SiteId == updateData.SiteId { - wgData.Sites[i] = siteConfig - break - } - } - } else { - logger.Error("WireGuard device not initialized") } + + // Format the endpoint before updating the peer. + siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) + + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } + + // Handle remote subnet route changes + if !stringSlicesEqual(oldRemoteSubnets, siteConfig.RemoteSubnets) { + if err := removeRoutesForRemoteSubnets(oldRemoteSubnets); err != nil { + logger.Error("Failed to remove old remote subnet routes: %v", err) + // Continue anyway to add new routes + } + + // Add new remote subnet routes + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add new remote subnet routes: %v", err) + return + } + } + + // Update successful + logger.Info("Successfully updated peer for site %d", updateData.SiteId) + wgData.Sites[peerIndex] = siteConfig }) // Handler for adding a new peer @@ -637,31 +652,31 @@ func StartTunnel(config TunnelConfig) { } // Add the peer to WireGuard - if dev != nil { - // Format the endpoint before adding the new peer. - siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { - logger.Error("Failed to add peer: %v", err) - return - } - if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { - logger.Error("Failed to add route for new peer: %v", err) - return - } - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { - logger.Error("Failed to add routes for remote subnets: %v", err) - return - } - - // Add successful - logger.Info("Successfully added peer for site %d", addData.SiteId) - - // Update WgData with the new peer - wgData.Sites = append(wgData.Sites, siteConfig) - } else { + if dev == nil { logger.Error("WireGuard device not initialized") + return } + // Format the endpoint before adding the new peer. + siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) + + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { + logger.Error("Failed to add route for new peer: %v", err) + return + } + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } + + // Add successful + logger.Info("Successfully added peer for site %d", addData.SiteId) + + // Update WgData with the new peer + wgData.Sites = append(wgData.Sites, siteConfig) }) // Handler for removing a peer diff --git a/olm/peer.go b/olm/peer.go index 1f8a5f4..6134d8f 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -34,10 +34,9 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes allowedIPs = append(allowedIPs, allowedIpStr) // If we have anything in remoteSubnets, add those as well - if siteConfig.RemoteSubnets != "" { - // Split remote subnets by comma and add each one - remoteSubnets := strings.Split(siteConfig.RemoteSubnets, ",") - for _, subnet := range remoteSubnets { + if len(siteConfig.RemoteSubnets) > 0 { + // Add each remote subnet + for _, subnet := range siteConfig.RemoteSubnets { subnet = strings.TrimSpace(subnet) if subnet != "" { allowedIPs = append(allowedIPs, subnet) diff --git a/olm/route.go b/olm/route.go index cc991fc..439d929 100644 --- a/olm/route.go +++ b/olm/route.go @@ -268,15 +268,14 @@ func removeRouteForNetworkConfig(destination string) error { return nil } -// addRoutesForRemoteSubnets adds routes for each comma-separated CIDR in RemoteSubnets -func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { - if remoteSubnets == "" { +// addRoutesForRemoteSubnets adds routes for each subnet in RemoteSubnets +func addRoutesForRemoteSubnets(remoteSubnets []string, interfaceName string) error { + if len(remoteSubnets) == 0 { return nil } - // Split remote subnets by comma and add routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { + // Add routes for each subnet + for _, subnet := range remoteSubnets { subnet = strings.TrimSpace(subnet) if subnet == "" { continue @@ -314,15 +313,14 @@ func addRoutesForRemoteSubnets(remoteSubnets, interfaceName string) error { return nil } -// removeRoutesForRemoteSubnets removes routes for each comma-separated CIDR in RemoteSubnets -func removeRoutesForRemoteSubnets(remoteSubnets string) error { - if remoteSubnets == "" { +// removeRoutesForRemoteSubnets removes routes for each subnet in RemoteSubnets +func removeRoutesForRemoteSubnets(remoteSubnets []string) error { + if len(remoteSubnets) == 0 { return nil } - // Split remote subnets by comma and remove routes for each one - subnets := strings.Split(remoteSubnets, ",") - for _, subnet := range subnets { + // Remove routes for each subnet + for _, subnet := range remoteSubnets { subnet = strings.TrimSpace(subnet) if subnet == "" { continue diff --git a/olm/types.go b/olm/types.go index 4ccdb8d..b7fb05a 100644 --- a/olm/types.go +++ b/olm/types.go @@ -6,12 +6,12 @@ type WgData struct { } type SiteConfig struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access } type HolePunchMessage struct { @@ -41,22 +41,22 @@ type PeerAction struct { // UpdatePeerData represents the data needed to update a peer type UpdatePeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint,omitempty"` + PublicKey string `json:"publicKey,omitempty"` + ServerIP string `json:"serverIP,omitempty"` + ServerPort uint16 `json:"serverPort,omitempty"` + RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access } // AddPeerData represents the data needed to add a peer type AddPeerData struct { - SiteId int `json:"siteId"` - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - ServerPort uint16 `json:"serverPort"` - RemoteSubnets string `json:"remoteSubnets,omitempty"` // optional, comma-separated list of subnets that this site can access + SiteId int `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` + ServerPort uint16 `json:"serverPort"` + RemoteSubnets []string `json:"remoteSubnets,omitempty"` // optional, array of subnets that this site can access } // RemovePeerData represents the data needed to remove a peer