From 45ef6e52794ac567dbfbffe3f39e9e59535c8dd8 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 1 Dec 2025 21:28:14 -0500 Subject: [PATCH] Migrate peer monitor into peer manager Former-commit-id: 29f0babf07d1c30116cc07caef77a5bf16f0ef71 --- olm/olm.go | 72 +++++----- peers/manager.go | 126 +++++++++++++++--- .../monitor/monitor.go | 39 +----- {peermonitor => peers/monitor}/wgtester.go | 2 +- peers/types.go | 1 + peers/{peer.go => wg.go} | 40 +----- 6 files changed, 154 insertions(+), 126 deletions(-) rename peermonitor/peermonitor.go => peers/monitor/monitor.go (94%) rename {peermonitor => peers/monitor}/wgtester.go (99%) rename peers/{peer.go => wg.go} (65%) diff --git a/olm/olm.go b/olm/olm.go index da04daf..6401984 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -20,7 +20,6 @@ import ( olmDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/dns" dnsOverride "github.com/fosrl/olm/dns/override" - "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/peers" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" @@ -32,7 +31,6 @@ var ( privateKey wgtypes.Key connected bool dev *device.Device - wgData WgData uapiListener net.Listener tdev tun.Device middleDev *olmDevice.MiddleDevice @@ -43,7 +41,6 @@ var ( tunnelRunning bool sharedBind *bind.SharedBind holePunchManager *holepunch.Manager - peerMonitor *peermonitor.PeerMonitor globalConfig GlobalConfig tunnelConfig TunnelConfig globalCtx context.Context @@ -269,6 +266,8 @@ func StartTunnel(config TunnelConfig) { olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) + var wgData WgData + if connected { logger.Info("Already connected. Ignoring new connection request.") return @@ -398,17 +397,28 @@ func StartTunnel(config TunnelConfig) { wsClientForMonitor = olm } - peerMonitor = peermonitor.NewPeerMonitor( - func(siteID int, connected bool, rtt time.Duration) { + // Create peer manager with integrated peer monitoring + peerManager = peers.NewPeerManager(peers.PeerManagerConfig{ + Device: dev, + DNSProxy: dnsProxy, + InterfaceName: interfaceName, + PrivateKey: privateKey, + MiddleDev: middleDev, + LocalIP: interfaceIP, + SharedBind: sharedBind, + WSClient: wsClientForMonitor, + StatusCallback: func(siteID int, connected bool, rtt time.Duration) { // Find the site config to get endpoint information var endpoint string var isRelay bool for _, site := range wgData.Sites { if site.SiteId == siteID { - endpoint = site.Endpoint - // TODO: We'll need to track relay status separately - // For now, assume not using relay unless we get relay data - isRelay = !config.Holepunch + if site.RelayEndpoint != "" { + endpoint = site.RelayEndpoint + } else { + endpoint = site.Endpoint + } + isRelay = site.RelayEndpoint != "" break } } @@ -419,43 +429,41 @@ func StartTunnel(config TunnelConfig) { logger.Warn("Peer %d is disconnected", siteID) } }, - wsClientForMonitor, - middleDev, - interfaceIP, - sharedBind, // Pass sharedBind for holepunch testing - ) - - peerManager = peers.NewPeerManager(dev, peerMonitor, dnsProxy, interfaceName, privateKey) + }) for i := range wgData.Sites { site := wgData.Sites[i] - apiServer.UpdatePeerStatus(site.SiteId, false, 0, site.Endpoint, false) + var siteEndpoint string + // here we are going to take the relay endpoint if it exists which means we requested a relay for this peer + if site.RelayEndpoint != "" { + siteEndpoint = site.RelayEndpoint + } else { + siteEndpoint = site.Endpoint + } + apiServer.UpdatePeerStatus(site.SiteId, false, 0, siteEndpoint, false) - if err := peerManager.AddPeer(site, endpoint); err != nil { + if err := peerManager.AddPeer(site, siteEndpoint); err != nil { logger.Error("Failed to add peer: %v", err) return } - // Add holepunch monitoring for this endpoint if holepunching is enabled - if config.Holepunch { - peerMonitor.AddHolepunchEndpoint(site.SiteId, site.Endpoint) - } - logger.Info("Configured peer %s", site.PublicKey) } - peerMonitor.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) { + peerManager.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) { // This callback is for additional handling if needed // The PeerMonitor already logs status changes logger.Info("+++++++++++++++++++++++++ holepunch monitor callback for site %d, endpoint %s, connected: %v, rtt: %v", siteID, endpoint, connected, rtt) }) - peerMonitor.Start() + peerManager.Start() - // Set up DNS override to use our DNS proxy - if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { - logger.Error("Failed to setup DNS override: %v", err) - return + if config.OverrideDNS { + // Set up DNS override to use our DNS proxy + if err := dnsOverride.SetupDNSOverride(interfaceName, dnsProxy); err != nil { + logger.Error("Failed to setup DNS override: %v", err) + return + } } if err := dnsProxy.Start(); err != nil { @@ -906,12 +914,8 @@ func Close() { updateRegister = nil } - if peerMonitor != nil { - peerMonitor.Close() // Close() also calls Stop() internally - peerMonitor = nil - } - if peerManager != nil { + peerManager.Close() // Close() also calls Stop() internally peerManager = nil } diff --git a/peers/manager.go b/peers/manager.go index 7b18350..12631b0 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -3,22 +3,50 @@ package peers import ( "fmt" "net" + "strconv" "strings" "sync" + "time" + "github.com/fosrl/newt/bind" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/network" + olmDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/dns" - "github.com/fosrl/olm/peermonitor" + "github.com/fosrl/olm/peers/monitor" + "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +// PeerStatusCallback is called when a peer's connection status changes +type PeerStatusCallback func(siteID int, connected bool, rtt time.Duration) + +// HolepunchStatusCallback is called when holepunch connection status changes +// This is an alias for monitor.HolepunchStatusCallback +type HolepunchStatusCallback = monitor.HolepunchStatusCallback + +// PeerManagerConfig contains the configuration for creating a PeerManager +type PeerManagerConfig struct { + Device *device.Device + DNSProxy *dns.DNSProxy + InterfaceName string + PrivateKey wgtypes.Key + // For peer monitoring + MiddleDev *olmDevice.MiddleDevice + LocalIP string + SharedBind *bind.SharedBind + // WSClient is optional - if nil, relay messages won't be sent + WSClient *websocket.Client + // StatusCallback is called when peer connection status changes + StatusCallback PeerStatusCallback +} + type PeerManager struct { mu sync.RWMutex device *device.Device peers map[int]SiteConfig - peerMonitor *peermonitor.PeerMonitor + peerMonitor *monitor.PeerMonitor dnsProxy *dns.DNSProxy interfaceName string privateKey wgtypes.Key @@ -28,19 +56,38 @@ type PeerManager struct { // allowedIPClaims tracks all peers that claim each allowed IP // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool + // statusCallback is called when peer connection status changes + statusCallback PeerStatusCallback } -func NewPeerManager(dev *device.Device, monitor *peermonitor.PeerMonitor, dnsProxy *dns.DNSProxy, interfaceName string, privateKey wgtypes.Key) *PeerManager { - return &PeerManager{ - device: dev, +// NewPeerManager creates a new PeerManager with an internal PeerMonitor +func NewPeerManager(config PeerManagerConfig) *PeerManager { + pm := &PeerManager{ + device: config.Device, peers: make(map[int]SiteConfig), - peerMonitor: monitor, - dnsProxy: dnsProxy, - interfaceName: interfaceName, - privateKey: privateKey, + dnsProxy: config.DNSProxy, + interfaceName: config.InterfaceName, + privateKey: config.PrivateKey, allowedIPOwners: make(map[string]int), allowedIPClaims: make(map[string]map[int]bool), + statusCallback: config.StatusCallback, } + + // Create the peer monitor + pm.peerMonitor = monitor.NewPeerMonitor( + func(siteID int, connected bool, rtt time.Duration) { + // Call the external status callback if set + if pm.statusCallback != nil { + pm.statusCallback(siteID, connected, rtt) + } + }, + config.WSClient, + config.MiddleDev, + config.LocalIP, + config.SharedBind, + ) + + return pm } func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { @@ -86,7 +133,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil { return err } @@ -104,6 +151,16 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig, endpoint string) error { pm.dnsProxy.AddDNSRecord(alias.Alias, address) } + monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] + monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port + + err := pm.peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer) + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) + } + pm.peers[siteConfig.SiteId] = siteConfig return nil } @@ -117,7 +174,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { return fmt.Errorf("peer with site ID %d not found", siteId) } - if err := RemovePeer(pm.device, siteId, peer.PublicKey, pm.peerMonitor); err != nil { + if err := RemovePeer(pm.device, siteId, peer.PublicKey); err != nil { return err } @@ -167,12 +224,16 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, promotedPeer.Endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } } + // Stop monitoring this peer + pm.peerMonitor.RemovePeer(siteId) + logger.Info("Stopped monitoring for site %d", siteId) + delete(pm.peers, siteId) return nil } @@ -188,7 +249,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error // If public key changed, remove old peer first if siteConfig.PublicKey != oldPeer.PublicKey { - if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey, pm.peerMonitor); err != nil { + if err := RemovePeer(pm.device, siteConfig.SiteId, oldPeer.PublicKey); err != nil { logger.Error("Failed to remove old peer: %v", err) } } @@ -237,7 +298,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, endpoint); err != nil { return err } @@ -247,7 +308,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig, endpoint string) error promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, promotedPeer.Endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, promotedPeer.Endpoint); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -399,7 +460,7 @@ func (pm *PeerManager) addAllowedIp(siteId int, ip string) error { // Only update WireGuard if we own this IP if pm.allowedIPOwners[ip] == siteId { - if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil { return err } } @@ -439,14 +500,14 @@ func (pm *PeerManager) removeAllowedIp(siteId int, cidr string) error { newOwner, promoted := pm.releaseAllowedIP(siteId, cidr) // Update WireGuard for this peer (to remove the IP from its config) - if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, peer, pm.privateKey, peer.Endpoint); err != nil { return err } // If another peer was promoted to owner, update their WireGuard config if promoted && newOwner >= 0 { if newOwnerPeer, exists := pm.peers[newOwner]; exists { - if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, newOwnerPeer.Endpoint, pm.peerMonitor); err != nil { + if err := ConfigurePeer(pm.device, newOwnerPeer, pm.privateKey, newOwnerPeer.Endpoint); err != nil { logger.Error("Failed to promote peer %d for IP %s: %v", newOwner, cidr, err) } else { logger.Info("Promoted peer %d to owner of IP %s", newOwner, cidr) @@ -626,3 +687,32 @@ endpoint=%s:21820`, peer.PublicKey, formattedEndpoint) logger.Info("Adjusted peer %d to point to relay!\n", siteId) } + +// Start starts the peer monitor +func (pm *PeerManager) Start() { + if pm.peerMonitor != nil { + pm.peerMonitor.Start() + } +} + +// Stop stops the peer monitor +func (pm *PeerManager) Stop() { + if pm.peerMonitor != nil { + pm.peerMonitor.Stop() + } +} + +// Close stops the peer monitor and cleans up resources +func (pm *PeerManager) Close() { + if pm.peerMonitor != nil { + pm.peerMonitor.Close() + pm.peerMonitor = nil + } +} + +// SetHolepunchStatusCallback sets the callback for holepunch status changes +func (pm *PeerManager) SetHolepunchStatusCallback(callback HolepunchStatusCallback) { + if pm.peerMonitor != nil { + pm.peerMonitor.SetHolepunchStatusCallback(callback) + } +} diff --git a/peermonitor/peermonitor.go b/peers/monitor/monitor.go similarity index 94% rename from peermonitor/peermonitor.go rename to peers/monitor/monitor.go index 59856a6..9a02408 100644 --- a/peermonitor/peermonitor.go +++ b/peers/monitor/monitor.go @@ -1,4 +1,4 @@ -package peermonitor +package monitor import ( "context" @@ -31,19 +31,9 @@ type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration) // HolepunchStatusCallback is called when holepunch connection status changes type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration) -// WireGuardConfig holds the WireGuard configuration for a peer -type WireGuardConfig struct { - SiteID int - PublicKey string - ServerIP string - Endpoint string - PrimaryRelay string // The primary relay endpoint -} - // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { monitors map[int]*Client - configs map[int]*WireGuardConfig callback PeerMonitorCallback mutex sync.Mutex running bool @@ -79,7 +69,6 @@ func NewPeerMonitor(callback PeerMonitorCallback, wsClient *websocket.Client, mi ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - configs: make(map[int]*WireGuardConfig), callback: callback, interval: 1 * time.Second, // Default check interval timeout: 2500 * time.Millisecond, @@ -149,7 +138,7 @@ func (pm *PeerMonitor) SetMaxAttempts(attempts int) { } // AddPeer adds a new peer to monitor -func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error { +func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error { pm.mutex.Lock() defer pm.mutex.Unlock() @@ -168,7 +157,8 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC client.SetMaxAttempts(pm.maxAttempts) pm.monitors[siteID] = client - pm.configs[siteID] = wgConfig + pm.holepunchEndpoints[siteID] = endpoint + pm.holepunchStatus[siteID] = false // Initially unknown/disconnected if pm.running { if err := client.StartMonitor(func(status ConnectionStatus) { @@ -192,7 +182,6 @@ func (pm *PeerMonitor) removePeerUnlocked(siteID int) { client.StopMonitor() client.Close() delete(pm.monitors, siteID) - delete(pm.configs, siteID) } // RemovePeer stops monitoring a peer and removes it from the monitor @@ -289,26 +278,6 @@ func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallba pm.holepunchStatusCallback = callback } -// AddHolepunchEndpoint adds an endpoint to monitor via holepunch magic packets -func (pm *PeerMonitor) AddHolepunchEndpoint(siteID int, endpoint string) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - pm.holepunchEndpoints[siteID] = endpoint - pm.holepunchStatus[siteID] = false // Initially unknown/disconnected - logger.Info("Added holepunch monitoring for site %d at %s", siteID, endpoint) -} - -// RemoveHolepunchEndpoint removes an endpoint from holepunch monitoring -func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) { - pm.mutex.Lock() - defer pm.mutex.Unlock() - - delete(pm.holepunchEndpoints, siteID) - delete(pm.holepunchStatus, siteID) - logger.Info("Removed holepunch monitoring for site %d", siteID) -} - // startHolepunchMonitor starts the holepunch connection monitoring // Note: This function assumes the mutex is already held by the caller (called from Start()) func (pm *PeerMonitor) startHolepunchMonitor() error { diff --git a/peermonitor/wgtester.go b/peers/monitor/wgtester.go similarity index 99% rename from peermonitor/wgtester.go rename to peers/monitor/wgtester.go index 05ce99a..15bf025 100644 --- a/peermonitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -1,4 +1,4 @@ -package peermonitor +package monitor import ( "context" diff --git a/peers/types.go b/peers/types.go index f984ba6..49d0924 100644 --- a/peers/types.go +++ b/peers/types.go @@ -10,6 +10,7 @@ type PeerAction struct { type SiteConfig struct { SiteId int `json:"siteId"` Endpoint string `json:"endpoint,omitempty"` + RelayEndpoint string `json:"relayEndpoint,omitempty"` PublicKey string `json:"publicKey,omitempty"` ServerIP string `json:"serverIP,omitempty"` ServerPort uint16 `json:"serverPort,omitempty"` diff --git a/peers/peer.go b/peers/wg.go similarity index 65% rename from peers/peer.go rename to peers/wg.go index 116d199..4bb91f3 100644 --- a/peers/peer.go +++ b/peers/wg.go @@ -2,19 +2,16 @@ package peers import ( "fmt" - "net" - "strconv" "strings" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" - "github.com/fosrl/olm/peermonitor" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string, peerMonitor *peermonitor.PeerMonitor) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, endpoint string) error { siteHost, err := util.ResolveDomain(formatEndpoint(siteConfig.Endpoint)) if err != nil { return fmt.Errorf("failed to resolve endpoint for site %d: %v", siteConfig.SiteId, err) @@ -68,38 +65,11 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes return fmt.Errorf("failed to configure WireGuard peer: %v", err) } - // Set up peer monitoring - if peerMonitor != nil { - monitorAddress := strings.Split(siteConfig.ServerIP, "/")[0] - monitorPeer := net.JoinHostPort(monitorAddress, strconv.Itoa(int(siteConfig.ServerPort+1))) // +1 for the monitor port - logger.Debug("Setting up peer monitor for site %d at %s", siteConfig.SiteId, monitorPeer) - logger.Debug("Resolving primary relay %s for peer", endpoint) - primaryRelay, err := util.ResolveDomain(endpoint) // Using global endpoint variable - if err != nil { - logger.Warn("Failed to resolve primary relay endpoint for peer: %v", err) - } - - wgConfig := &peermonitor.WireGuardConfig{ - SiteID: siteConfig.SiteId, - PublicKey: util.FixKey(siteConfig.PublicKey), - ServerIP: strings.Split(siteConfig.ServerIP, "/")[0], - Endpoint: siteConfig.Endpoint, - PrimaryRelay: primaryRelay, - } - - err = peerMonitor.AddPeer(siteConfig.SiteId, monitorPeer, wgConfig) - if err != nil { - logger.Warn("Failed to setup monitoring for site %d: %v", siteConfig.SiteId, err) - } else { - logger.Info("Started monitoring for site %d at %s", siteConfig.SiteId, monitorPeer) - } - } - return nil } // RemovePeer removes a peer from the WireGuard device -func RemovePeer(dev *device.Device, siteId int, publicKey string, peerMonitor *peermonitor.PeerMonitor) error { +func RemovePeer(dev *device.Device, siteId int, publicKey string) error { // Construct WireGuard config to remove the peer var configBuilder strings.Builder configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) @@ -113,12 +83,6 @@ func RemovePeer(dev *device.Device, siteId int, publicKey string, peerMonitor *p return fmt.Errorf("failed to remove WireGuard peer: %v", err) } - // Stop monitoring this peer - if peerMonitor != nil { - peerMonitor.RemovePeer(siteId) - logger.Info("Stopped monitoring for site %d", siteId) - } - return nil }