Remove the old peer when updating new peer

This commit is contained in:
Owen
2025-09-28 12:25:33 -07:00
parent 1054d70192
commit 74b166e82f

91
main.go
View File

@@ -53,7 +53,6 @@ func formatEndpoint(endpoint string) string {
return endpoint return endpoint
} }
func main() { func main() {
// Check if we're running as a Windows service // Check if we're running as a Windows service
if isWindowsService() { if isWindowsService() {
@@ -598,30 +597,47 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
fileUAPI, err := func() (*os.File, error) { fileUAPI, err := func() (*os.File, error) {
if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" { if uapiFdStr := os.Getenv(ENV_WG_UAPI_FD); uapiFdStr != "" {
fd, err := strconv.ParseUint(uapiFdStr, 10, 32) fd, err := strconv.ParseUint(uapiFdStr, 10, 32)
if err != nil { return nil, err } if err != nil {
return nil, err
}
return os.NewFile(uintptr(fd), ""), nil return os.NewFile(uintptr(fd), ""), nil
} }
return uapiOpen(interfaceName) return uapiOpen(interfaceName)
}() }()
if err != nil { logger.Error("UAPI listen error: %v", err); os.Exit(1); return } if err != nil {
logger.Error("UAPI listen error: %v", err)
os.Exit(1)
return
}
dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: ")) dev = device.NewDevice(tdev, NewFixedPortBind(uint16(sourcePort)), device.NewLogger(mapToWireGuardLogLevel(loggerLevel), "wireguard: "))
uapiListener, err = uapiListen(interfaceName, fileUAPI) uapiListener, err = uapiListen(interfaceName, fileUAPI)
if err != nil { logger.Error("Failed to listen on uapi socket: %v", err); os.Exit(1) } if err != nil {
logger.Error("Failed to listen on uapi socket: %v", err)
os.Exit(1)
}
go func() { go func() {
for { for {
conn, err := uapiListener.Accept() conn, err := uapiListener.Accept()
if err != nil { return } if err != nil {
return
}
go dev.IpcHandle(conn) go dev.IpcHandle(conn)
} }
}() }()
logger.Info("UAPI listener started") logger.Info("UAPI listener started")
if err = dev.Up(); err != nil { logger.Error("Failed to bring up WireGuard device: %v", err) } if err = dev.Up(); err != nil {
if err = ConfigureInterface(interfaceName, wgData); err != nil { logger.Error("Failed to configure interface: %v", err) } logger.Error("Failed to bring up WireGuard device: %v", err)
if httpServer != nil { httpServer.SetTunnelIP(wgData.TunnelIP) } }
if err = ConfigureInterface(interfaceName, wgData); err != nil {
logger.Error("Failed to configure interface: %v", err)
}
if httpServer != nil {
httpServer.SetTunnelIP(wgData.TunnelIP)
}
peerMonitor = peermonitor.NewPeerMonitor( peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) { func(siteID int, connected bool, rtt time.Duration) {
@@ -661,9 +677,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Format the endpoint before configuring the peer. // Format the endpoint before configuring the peer.
site.Endpoint = formatEndpoint(site.Endpoint) site.Endpoint = formatEndpoint(site.Endpoint)
if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil { logger.Error("Failed to configure peer: %v", err); return } if err := ConfigurePeer(dev, *site, privateKey, endpoint); err != nil {
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for peer: %v", err); return } logger.Error("Failed to configure peer: %v", err)
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return } return
}
if err := addRouteForServerIP(site.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(site.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
logger.Info("Configured peer %s", site.PublicKey) logger.Info("Configured peer %s", site.PublicKey)
} }
@@ -702,19 +727,33 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Update the peer in WireGuard // Update the peer in WireGuard
if dev != nil { if dev != nil {
// Find the existing peer to get old RemoteSubnets // Find the existing peer to get old data
var oldRemoteSubnets string var oldRemoteSubnets string
var oldPublicKey string
for _, site := range wgData.Sites { for _, site := range wgData.Sites {
if site.SiteId == updateData.SiteId { if site.SiteId == updateData.SiteId {
oldRemoteSubnets = site.RemoteSubnets oldRemoteSubnets = site.RemoteSubnets
oldPublicKey = site.PublicKey
break break
} }
} }
// If the public key has changed, remove the old peer first
if oldPublicKey != "" && oldPublicKey != updateData.PublicKey {
logger.Info("Public key changed for site %d, removing old peer with key %s", updateData.SiteId, oldPublicKey)
if err := RemovePeer(dev, updateData.SiteId, oldPublicKey); err != nil {
logger.Error("Failed to remove old peer: %v", err)
return
}
}
// Format the endpoint before updating the peer. // Format the endpoint before updating the peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to update peer: %v", err); return } if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
logger.Error("Failed to update peer: %v", err)
return
}
// Remove old remote subnet routes if they changed // Remove old remote subnet routes if they changed
if oldRemoteSubnets != siteConfig.RemoteSubnets { if oldRemoteSubnets != siteConfig.RemoteSubnets {
@@ -733,7 +772,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Update successful // Update successful
logger.Info("Successfully updated peer for site %d", updateData.SiteId) logger.Info("Successfully updated peer for site %d", updateData.SiteId)
for i := range wgData.Sites { for i := range wgData.Sites {
if wgData.Sites[i].SiteId == updateData.SiteId { wgData.Sites[i] = siteConfig; break } if wgData.Sites[i].SiteId == updateData.SiteId {
wgData.Sites[i] = siteConfig
break
}
} }
} else { } else {
logger.Error("WireGuard device not initialized") logger.Error("WireGuard device not initialized")
@@ -771,9 +813,18 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
// Format the endpoint before adding the new peer. // Format the endpoint before adding the new peer.
siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint) siteConfig.Endpoint = formatEndpoint(siteConfig.Endpoint)
if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil { logger.Error("Failed to add peer: %v", err); return } if err := ConfigurePeer(dev, siteConfig, privateKey, endpoint); err != nil {
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil { logger.Error("Failed to add route for new peer: %v", err); return } logger.Error("Failed to add peer: %v", err)
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil { logger.Error("Failed to add routes for remote subnets: %v", err); return } return
}
if err := addRouteForServerIP(siteConfig.ServerIP, interfaceName); err != nil {
logger.Error("Failed to add route for new peer: %v", err)
return
}
if err := addRoutesForRemoteSubnets(siteConfig.RemoteSubnets, interfaceName); err != nil {
logger.Error("Failed to add routes for remote subnets: %v", err)
return
}
// Add successful // Add successful
logger.Info("Successfully added peer for site %d", addData.SiteId) logger.Info("Successfully added peer for site %d", addData.SiteId)