From dd00289f8e6e094f63e46754d4fde8e77985b3a5 Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 28 Sep 2025 12:25:33 -0700 Subject: [PATCH] Remove the old peer when updating new peer Former-commit-id: 74b166e82f18d421356cbd9ebfe89576576ff99d --- main.go | 91 ++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 71 insertions(+), 20 deletions(-) diff --git a/main.go b/main.go index 9d0ff10..82fbd8e 100644 --- a/main.go +++ b/main.go @@ -53,7 +53,6 @@ func formatEndpoint(endpoint string) string { return endpoint } - func main() { // Check if we're running as a Windows service if isWindowsService() { @@ -598,30 +597,47 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { fileUAPI, err := func() (*os.File, error) { if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" { fd, err := strconv.ParseUint(uapiFdStr, 10, 32) - if err != nil { return nil, err } + if err != nil { + return nil, err + } return os.NewFile(uintptr(fd), ""), nil } return uapiOpen(interfaceName) }() - if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return } + if err != nil { + logger.Error("UAPI listen error: %v", err) + os.Exit(1) + return + } dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) - + uapiListener, err = uapiListen(interfaceName, fileUAPI) - if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) } + if err != nil { + logger.Error("Failed to listen on uapi socket: %v", err) + os.Exit(1) + } go func() { for { conn, err := uapiListener.Accept() - if err != nil { return } + if err != nil { + return + } go dev.IpcHandle(conn) } }() logger.Info("UAPI listener started") - if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) } - if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } - if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) } + if err = dev.Up(); err != nil { + logger.Error("Failed to bring up WireGuard device: %v", err) + } + if err = ConfigureInterface(interfaceName, wgData); err != nil { + logger.Error("Failed to configure interface: %v", err) + } + if httpServer != nil { + httpServer.SetTunnelIP(wgData.TunnelIP) + } peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { @@ -661,9 +677,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Format the endpoint before configuring the peer. site.Endpoint = formatEndpoint(site.Endpoint) - if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return } - if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return } - if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return } + if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { + logger.Error("Failed to configure peer: %v", err) + return + } + if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { + logger.Error("Failed to add route for peer: %v", err) + return + } + if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } logger.Info("Configured peer %s", site.PublicKey) } @@ -702,19 +727,33 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Update the peer in WireGuard if dev != nil { - // Find the existing peer to get old RemoteSubnets + // Find the existing peer to get old data var oldRemoteSubnets string + var oldPublicKey string for _, site := range wgData.Sites { if site.SiteId == updateData.SiteId { oldRemoteSubnets = site.RemoteSubnets + oldPublicKey = site.PublicKey break } } - + + // If the public key has changed, remove the old peer first + if oldPublicKey != "" && oldPublicKey != updateData.PublicKey { + logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey) + if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil { + logger.Error("Failed to remove old peer: %v", err) + return + } + } + // Format the endpoint before updating the peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return } + + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to update peer: %v", err) + return + } // Remove old remote subnet routes if they changed if oldRemoteSubnets != siteConfig.RemoteSubnets { @@ -733,7 +772,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Update successful logger.Info("Successfully updated peer for site %d", updateData.SiteId) for i := range wgData.Sites { - if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break } + if wgData.Sites[i].SiteId == updateData.SiteId { + wgData.Sites[i] = siteConfig + break + } } } else { logger.Error("WireGuard device not initialized") @@ -771,9 +813,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { // Format the endpoint before adding the new peer. siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) - if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return } - if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return } - if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return } + if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { + logger.Error("Failed to add peer: %v", err) + return + } + if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { + logger.Error("Failed to add route for new peer: %v", err) + return + } + if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { + logger.Error("Failed to add routes for remote subnets: %v", err) + return + } // Add successful logger.Info("Successfully added peer for site %d", addData.SiteId)