diff --git a/peers/manager.go b/peers/manager.go index 9cc1e75..b371aa1 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/fosrl/newt/bind" "github.com/fosrl/newt/logger" @@ -54,6 +55,9 @@ type PeerManager struct { publicDNS []string PersistentKeepalive int + + routeOptimizerStop chan struct{} + optimizerTrigger chan struct{} } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -80,6 +84,8 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager { config.PublicDNS, ) + pm.optimizerTrigger = make(chan struct{}, 1) + return pm } @@ -856,10 +862,12 @@ func (pm *PeerManager) Start() { if pm.peerMonitor != nil { pm.peerMonitor.Start() } + pm.startRouteOptimizer() } // Stop stops the peer monitor func (pm *PeerManager) Stop() { + pm.stopRouteOptimizer() if pm.peerMonitor != nil { pm.peerMonitor.Stop() } @@ -867,6 +875,7 @@ func (pm *PeerManager) Stop() { // Close stops the peer monitor and cleans up resources func (pm *PeerManager) Close() { + pm.stopRouteOptimizer() if pm.peerMonitor != nil { pm.peerMonitor.Close() pm.peerMonitor = nil @@ -928,3 +937,166 @@ endpoint=%s`, util.FixKey(peer.PublicKey), endpoint) logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint) return nil } + +// isBetterConnection returns true if connection quality (a) is better than (b). +// Priority: connected > disconnected, then direct > relayed, then lower RTT. 
+func isBetterConnection(aConn bool, aRelay bool, aRTT time.Duration,
+	bConn bool, bRelay bool, bRTT time.Duration) bool {
+	switch {
+	case aConn != bConn:
+		return aConn // connected beats disconnected
+	case !aConn:
+		return false // both offline, no preference
+	case aRelay != bRelay:
+		return !aRelay // direct beats relayed
+	case aRTT == 0:
+		return false // a has no RTT sample, so it must not displace b
+	case bRTT == 0:
+		return true // b has no RTT sample but a does: prefer the measured one
+	}
+	return aRTT < bRTT // same connectivity class, both measured: lower RTT wins
+}
+
+// selectBestOwner returns the siteId with the best connection quality among
+// claims, or -1 if claims is empty or no peer monitor is set. Candidates of equal
+// quality resolve to the lowest siteId, so the result never depends on Go's
+// random map iteration order. Must be called with pm.mu held.
+func (pm *PeerManager) selectBestOwner(claims map[int]bool) int {
+	bestSiteId := -1
+	if pm.peerMonitor == nil {
+		return bestSiteId // Close() sets peerMonitor to nil; nothing to rank with
+	}
+	var bestConn, bestRelay bool
+	var bestRTT time.Duration
+	for siteId := range claims {
+		conn, relay, rtt := pm.peerMonitor.GetConnectionQuality(siteId)
+		better := bestSiteId < 0 || isBetterConnection(conn, relay, rtt, bestConn, bestRelay, bestRTT)
+		// A tie is when neither side beats the other; the lower siteId wins it.
+		tied := !better && siteId < bestSiteId &&
+			!isBetterConnection(bestConn, bestRelay, bestRTT, conn, relay, rtt)
+		if better || tied {
+			bestSiteId, bestConn, bestRelay, bestRTT = siteId, conn, relay, rtt
+		}
+	}
+	return bestSiteId
+}
+
+// getWireGuardAllowedIPs returns the full set of IPs that should be in WireGuard
+// for a peer: server IP /32 plus all shared IPs it currently owns, or nil if the
+// site is not in pm.peers. Must be called with pm.mu held.
+func (pm *PeerManager) getWireGuardAllowedIPs(siteId int) []string {
+	peer, exists := pm.peers[siteId]
+	if !exists {
+		return nil
+	}
+	ips := []string{strings.Split(peer.ServerIP, "/")[0] + "/32"}
+	for cidr, owner := range pm.allowedIPOwners {
+		if owner == siteId {
+			ips = append(ips, cidr)
+		}
+	}
+	return ips
+}
+
+// transferOwnership moves WireGuard ownership of cidr from fromSiteId to toSiteId.
+// Must be called with pm.mu held.
+func (pm *PeerManager) transferOwnership(cidr string, fromSiteId int, toSiteId int) error {
+	// On any error the owner map is put back to fromSiteId (and a best-effort
+	// attempt is made to restore cidr on the old WireGuard peer), so a failed
+	// transfer is retried by the next optimizer pass instead of looking complete.
+
+	// Update owner map first so getWireGuardAllowedIPs(fromSiteId) below no
+	// longer reports cidr.
+	pm.allowedIPOwners[cidr] = toSiteId
+
+	// Remove cidr from old owner's WireGuard allowed IPs
+	fromPeer, hasFromPeer := pm.peers[fromSiteId]
+	if hasFromPeer {
+		remaining := pm.getWireGuardAllowedIPs(fromSiteId)
+		if err := RemoveAllowedIP(pm.device, fromPeer.PublicKey, remaining); err != nil {
+			pm.allowedIPOwners[cidr] = fromSiteId // revert
+			return fmt.Errorf("remove IP %s from site %d: %w", cidr, fromSiteId, err)
+		}
+	}
+
+	// Add cidr to new owner's WireGuard allowed IPs
+	if toPeer, exists := pm.peers[toSiteId]; exists {
+		if err := AddAllowedIP(pm.device, toPeer.PublicKey, cidr); err != nil {
+			// Without this rollback the map would name toSiteId as owner while
+			// no WireGuard peer carries cidr, and optimizeRoutes would skip it.
+			pm.allowedIPOwners[cidr] = fromSiteId
+			if hasFromPeer {
+				if restoreErr := AddAllowedIP(pm.device, fromPeer.PublicKey, cidr); restoreErr != nil {
+					logger.Error("Failed to restore IP %s on site %d: %v", cidr, fromSiteId, restoreErr)
+				}
+			}
+			return fmt.Errorf("add IP %s to site %d: %w", cidr, toSiteId, err)
+		}
+	}
+
+	return nil
+}
+
+// optimizeRoutes evaluates all shared IPs and reassigns ownership to the best site.
+// It takes pm.mu for the whole pass. CIDRs claimed by at most one site are left
+// alone. Failures are logged, not returned; the owner map is left so that the
+// next pass tries the same CIDR again.
+func (pm *PeerManager) optimizeRoutes() {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	for cidr, claims := range pm.allowedIPClaims {
+		if len(claims) <= 1 {
+			continue // No competition, nothing to optimize
+		}
+
+		currentOwner, hasOwner := pm.allowedIPOwners[cidr]
+		bestOwner := pm.selectBestOwner(claims)
+		if bestOwner < 0 || (hasOwner && currentOwner == bestOwner) {
+			continue // No candidate, or already on the best site
+		}
+
+		if !hasOwner {
+			// No current owner, just assign
+			pm.allowedIPOwners[cidr] = bestOwner
+			if toPeer, exists := pm.peers[bestOwner]; exists {
+				if err := AddAllowedIP(pm.device, toPeer.PublicKey, cidr); err != nil {
+					logger.Error("Failed to assign IP %s to site %d: %v", cidr, bestOwner, err)
+					// Forget the owner again: keeping it would make the check
+					// above treat cidr as already placed and never retry.
+					delete(pm.allowedIPOwners, cidr)
+				}
+			}
+			continue
+		}
+
+		logger.Info("Route optimizer: moving %s from site %d to site %d", cidr, currentOwner, bestOwner)
+		if err := pm.transferOwnership(cidr, currentOwner, bestOwner); err != nil {
+			logger.Error("Failed to transfer ownership of %s from site %d to site %d: %v",
+				cidr, currentOwner, bestOwner, err)
+		}
+	}
+}
+
+// startRouteOptimizer registers the status-change callback and launches the optimizer goroutine.
+// The goroutine runs optimizeRoutes on every trigger and on a 5 second ticker
+// until stopRouteOptimizer closes the stop channel. It does nothing if the
+// optimizer is already running or if there is no peer monitor.
+func (pm *PeerManager) startRouteOptimizer() {
+	if pm.routeOptimizerStop != nil {
+		return // already running; starting again would orphan the first goroutine
+	}
+	if pm.peerMonitor == nil {
+		return // selectBestOwner needs the monitor for connection quality
+	}
+
+	// The goroutine keeps its own reference: stopRouteOptimizer sets the field
+	// to nil after closing it, and a receive from a nil channel blocks forever.
+	stop := make(chan struct{})
+	pm.routeOptimizerStop = stop
+
+	// Trigger optimization whenever any peer's connection status changes
+	pm.peerMonitor.SetStatusChangeCallback(func(_ int) {
+		select {
+		case pm.optimizerTrigger <- struct{}{}:
+		default: // send would block: a trigger is already queued
+		}
+	})
+
+	go func() {
+		ticker := time.NewTicker(5 * time.Second)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-stop:
+				return
+			case <-pm.optimizerTrigger:
+			case <-ticker.C:
+			}
+			// select picks at random among ready cases, so stop may already be
+			// closed here; check it again before starting another pass.
+			select {
+			case <-stop:
+				return
+			default:
+			}
+			pm.optimizeRoutes()
+		}
+	}()
+}
+
+// stopRouteOptimizer stops the route optimizer goroutine if it is running.
+// It only closes the stop channel and clears the field; it does not wait for
+// the goroutine to return. Calling it again afterwards is a no-op.
+func (pm *PeerManager) stopRouteOptimizer() {
+	if pm.routeOptimizerStop != nil {
+		close(pm.routeOptimizerStop)
+		pm.routeOptimizerStop = nil
+	}
+}
diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go
index 56dcee4..b7af451 100644
--- a/peers/monitor/monitor.go
+++ b/peers/monitor/monitor.go
@@ -85,7 +85,9 @@ type PeerMonitor struct {
 	apiServer *api.API
 
 	// WG connection status tracking
-	wgConnectionStatus map[int]bool // siteID -> WG connected status
+	wgConnectionStatus   map[int]bool          // siteID -> WG connected status
+	wgConnectionRTT      map[int]time.Duration // siteID -> last known RTT
+	statusChangeCallback func(siteId int)      // called when any peer's connection status changes
 }
 
 // NewPeerMonitor creates a new peer monitor with the given callback
@@ -122,6 +124,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
 		apiServer: apiServer,
 		wgConnectionStatus: make(map[int]bool),
+		wgConnectionRTT: make(map[int]time.Duration),
 		// Exponential backoff settings for holepunch monitor
 		defaultHolepunchMinInterval: 2 * time.Second,
 		defaultHolepunchMaxInterval: 30 * time.Second,
@@ -392,6 +395,9 @@ 
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
 	pm.mutex.Lock()
 	previousStatus, exists := pm.wgConnectionStatus[siteID]
 	pm.wgConnectionStatus[siteID] = status.Connected
+	if status.Connected && status.RTT > 0 {
+		pm.wgConnectionRTT[siteID] = status.RTT
+	}
 	isRelayed := pm.relayedPeers[siteID]
 	endpoint := pm.holepunchEndpoints[siteID]
 	pm.mutex.Unlock()
@@ -409,6 +415,12 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
 	if pm.apiServer != nil {
 		pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed)
 	}
+
+	// Notify route optimizer of status change. The callback is fetched through
+	// currentStatusCallback so the read is synchronized with SetStatusChangeCallback.
+	if notify := pm.currentStatusCallback(); notify != nil {
+		notify(siteID)
+	}
 }
 
 // sendRelay sends a relay message to the server with retry, keyed by chainId
@@ -521,6 +533,39 @@ func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool {
 	return pm.relayedPeers[siteID]
 }
 
+// SetStatusChangeCallback registers a callback that is invoked whenever a peer's
+// WireGuard connection status changes (connected/disconnected). The callback must
+// be non-blocking (e.g., send to a buffered channel). Passing nil disables the
+// notification.
+func (pm *PeerMonitor) SetStatusChangeCallback(cb func(siteID int)) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	pm.statusChangeCallback = cb
+}
+
+// currentStatusCallback returns the registered status-change callback (nil if
+// none). It reads the field under pm.mutex because SetStatusChangeCallback may
+// write it concurrently; callers must not hold pm.mutex.
+func (pm *PeerMonitor) currentStatusCallback() func(siteID int) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	return pm.statusChangeCallback
+}
+
+// GetConnectionQuality returns the current connection quality metrics for a peer:
+// whether WireGuard is connected, whether the peer is relayed, and the last RTT
+// stored for it. A site the monitor has no entry for yields false, false, 0.
+// handleConnectionStatusChange only stores an RTT while the peer reports
+// connected with a positive RTT, so the value can be stale after a disconnect.
+func (pm *PeerMonitor) GetConnectionQuality(siteID int) (connected bool, relayed bool, rtt time.Duration) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	connected = pm.wgConnectionStatus[siteID]
+	relayed = pm.relayedPeers[siteID]
+	rtt = pm.wgConnectionRTT[siteID]
+	return connected, relayed, rtt
+}
+
 // startHolepunchMonitor starts the holepunch connection monitoring
 // Note: This function assumes the mutex is already held by the caller (called from Start())
 func (pm *PeerMonitor) startHolepunchMonitor() error {