diff --git a/olm/olm.go b/olm/olm.go
index ee36c29..3035cbd 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -686,15 +686,41 @@ func StartTunnel(config TunnelConfig) {
             return
         }
 
+        primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
+        if err != nil {
+            logger.Warn("Failed to resolve primary relay endpoint: %v", err)
+        }
+
+        // Update HTTP server to mark this peer as using relay
+        apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
+
+        peerManager.RelayPeer(relayData.SiteId, primaryRelay)
+    })
+
+    olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
+        logger.Debug("Received unrelay-peer message: %v", msg.Data)
+
+        jsonData, err := json.Marshal(msg.Data)
+        if err != nil {
+            logger.Error("Error marshaling data: %v", err)
+            return
+        }
+
+        var relayData peers.UnRelayPeerData
+        if err := json.Unmarshal(jsonData, &relayData); err != nil {
+            logger.Error("Error unmarshaling relay data: %v", err)
+            return
+        }
+
         primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
         if err != nil {
             logger.Warn("Failed to resolve primary relay endpoint: %v", err)
         }
 
         // Update HTTP server to mark this peer as using relay
-        apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
+        apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
 
-        peerManager.HandleFailover(relayData.SiteId, primaryRelay)
+        peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
     })
 
     // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
diff --git a/peers/manager.go b/peers/manager.go
index 4cd8332..fe71a19 100644
--- a/peers/manager.go
+++ b/peers/manager.go
@@ -6,11 +6,11 @@ import (
     "strconv"
     "strings"
     "sync"
-    "time"
 
     "github.com/fosrl/newt/bind"
     "github.com/fosrl/newt/logger"
     "github.com/fosrl/newt/network"
+    "github.com/fosrl/newt/util"
     "github.com/fosrl/olm/api"
     olmDevice "github.com/fosrl/olm/device"
     "github.com/fosrl/olm/dns"
@@ -20,10 +20,6 @@ import (
     "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
 )
 
-// HolepunchStatusCallback is called when holepunch connection status changes
-// This is an alias for monitor.HolepunchStatusCallback
-type HolepunchStatusCallback = monitor.HolepunchStatusCallback
-
 // PeerManagerConfig contains the configuration for creating a PeerManager
 type PeerManagerConfig struct {
     Device *device.Device
@@ -71,34 +67,6 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
 
     // Create the peer monitor
     pm.peerMonitor = monitor.NewPeerMonitor(
-        func(siteID int, connected bool, rtt time.Duration) {
-            // Update API status directly
-            if pm.APIServer != nil {
-                // Find the peer config to get endpoint information
-                pm.mu.RLock()
-                peer, exists := pm.peers[siteID]
-                pm.mu.RUnlock()
-
-                var endpoint string
-                var isRelay bool
-                if exists {
-                    if peer.RelayEndpoint != "" {
-                        endpoint = peer.RelayEndpoint
-                        isRelay = true
-                    } else {
-                        endpoint = peer.Endpoint
-                        isRelay = false
-                    }
-                }
-                pm.APIServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
-            }
-
-            if connected {
-                logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
-            } else {
-                logger.Warn("Peer %d is disconnected", siteID)
-            }
-        },
         config.WSClient,
         config.MiddleDev,
         config.LocalIP,
@@ -677,11 +645,16 @@ func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
     return nil
 }
 
-// HandleFailover handles failover to the relay server when a peer is disconnected
-func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
-    pm.mu.RLock()
+// RelayPeer handles failover to the relay server when a peer is disconnected
+func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
+    pm.mu.Lock()
     peer, exists := pm.peers[siteId]
-    pm.mu.RUnlock()
+    if exists {
+        // Store the relay endpoint
+        peer.RelayEndpoint = relayEndpoint
+        pm.peers[siteId] = peer
+    }
+    pm.mu.Unlock()
 
     if !exists {
         logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
@@ -697,7 +670,7 @@ func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
     // Update only the endpoint for this peer (update_only preserves other settings)
     wgConfig := fmt.Sprintf(`public_key=%s
 update_only=true
-endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
+endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
 
     err := pm.device.IpcSet(wgConfig)
     if err != nil {
@@ -705,6 +678,11 @@ endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
         return
     }
 
+    // Mark the peer as relayed in the monitor
+    if pm.peerMonitor != nil {
+        pm.peerMonitor.MarkPeerRelayed(siteId, true)
+    }
+
     logger.Info("Adjusted peer %d to point to relay!\n", siteId)
 }
 
@@ -730,9 +708,58 @@ func (pm *PeerManager) Close() {
     }
 }
 
-// SetHolepunchStatusCallback sets the callback for holepunch status changes
-func (pm *PeerManager) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
+// MarkPeerRelayed marks a peer as currently using relay
+func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
+    pm.mu.Lock()
+    if peer, exists := pm.peers[siteID]; exists {
+        if relayed {
+            // Nothing to store here; RelayEndpoint was already
+            // set when RelayPeer handled the failover
+        } else {
+            // Clear relay endpoint when switching back to direct
+            peer.RelayEndpoint = ""
+            pm.peers[siteID] = peer
+        }
+    }
+    pm.mu.Unlock()
+
     if pm.peerMonitor != nil {
-        pm.peerMonitor.SetHolepunchStatusCallback(callback)
+        pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
     }
 }
+
+// UnRelayPeer switches a peer from relay back to direct connection
+func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
+    pm.mu.Lock()
+    peer, exists := pm.peers[siteId]
+    if exists {
+        // Store the direct endpoint
+        peer.Endpoint = endpoint
+        pm.peers[siteId] = peer
+    }
+    pm.mu.Unlock()
+
+    if !exists {
+        logger.Error("Cannot switch to direct connection: peer with site ID %d not found", siteId)
+        return nil
+    }
+
+    // Update WireGuard to use the direct endpoint
+    wgConfig := fmt.Sprintf(`public_key=%s
+update_only=true
+endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
+
+    err := pm.device.IpcSet(wgConfig)
+    if err != nil {
+        logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
+        return err
+    }
+
+    // Mark as not relayed in monitor
+    if pm.peerMonitor != nil {
+        pm.peerMonitor.MarkPeerRelayed(siteId, false)
+    }
+
+    logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
+    return nil
+}
diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go
index d7055d2..59bbbef 100644
--- a/peers/monitor/monitor.go
+++ b/peers/monitor/monitor.go
@@ -25,16 +25,9 @@ import (
     "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
 )
 
-// PeerMonitorCallback is the function type for connection status change callbacks
-type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
-
-// HolepunchStatusCallback is called when holepunch connection status changes
-type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)
-
 // PeerMonitor handles monitoring the connection status to multiple WireGuard peers
 type PeerMonitor struct {
     monitors map[int]*Client
-    callback PeerMonitorCallback
     mutex    sync.Mutex
     running  bool
     interval time.Duration
@@ -54,36 +47,42 @@ type PeerMonitor struct {
     nsWg     sync.WaitGroup
 
     // Holepunch testing fields
-    sharedBind              *bind.SharedBind
-    holepunchTester         *holepunch.HolepunchTester
-    holepunchInterval       time.Duration
-    holepunchTimeout        time.Duration
-    holepunchEndpoints      map[int]string // siteID -> endpoint for holepunch testing
-    holepunchStatus         map[int]bool   // siteID -> connected status
-    holepunchStatusCallback HolepunchStatusCallback
-    holepunchStopChan       chan struct{}
+    sharedBind         *bind.SharedBind
+    holepunchTester    *holepunch.HolepunchTester
+    holepunchInterval  time.Duration
+    holepunchTimeout   time.Duration
+    holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
+    holepunchStatus    map[int]bool   // siteID -> connected status
+    holepunchStopChan  chan struct{}
+
+    // Relay tracking fields
+    relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
+    holepunchMaxAttempts int          // max consecutive failures before triggering relay
+    holepunchFailures    map[int]int  // siteID -> consecutive failure count
 }
 
 // NewPeerMonitor creates a new peer monitor with the given callback
-func NewPeerMonitor(callback PeerMonitorCallback, wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
+func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
     ctx, cancel := context.WithCancel(context.Background())
     pm := &PeerMonitor{
-        monitors:           make(map[int]*Client),
-        callback:           callback,
-        interval:           1 * time.Second, // Default check interval
-        timeout:            2500 * time.Millisecond,
-        maxAttempts:        15,
-        wsClient:           wsClient,
-        middleDev:          middleDev,
-        localIP:            localIP,
-        activePorts:        make(map[uint16]bool),
-        nsCtx:              ctx,
-        nsCancel:           cancel,
-        sharedBind:         sharedBind,
-        holepunchInterval:  5 * time.Second, // Check holepunch every 5 seconds
-        holepunchTimeout:   3 * time.Second,
-        holepunchEndpoints: make(map[int]string),
-        holepunchStatus:    make(map[int]bool),
+        monitors:             make(map[int]*Client),
+        interval:             1 * time.Second, // Default check interval
+        timeout:              2500 * time.Millisecond,
+        maxAttempts:          15,
+        wsClient:             wsClient,
+        middleDev:            middleDev,
+        localIP:              localIP,
+        activePorts:          make(map[uint16]bool),
+        nsCtx:                ctx,
+        nsCancel:             cancel,
+        sharedBind:           sharedBind,
+        holepunchInterval:    5 * time.Second, // Check holepunch every 5 seconds
+        holepunchTimeout:     3 * time.Second,
+        holepunchEndpoints:   make(map[int]string),
+        holepunchStatus:      make(map[int]bool),
+        relayedPeers:         make(map[int]bool),
+        holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
+        holepunchFailures:    make(map[int]int),
     }
 
     if err := pm.initNetstack(); err != nil {
@@ -201,6 +200,8 @@ func (pm *PeerMonitor) RemovePeer(siteID int) {
     // remove the holepunch endpoint info
     delete(pm.holepunchEndpoints, siteID)
     delete(pm.holepunchStatus, siteID)
+    delete(pm.relayedPeers, siteID)
+    delete(pm.holepunchFailures, siteID)
 
     pm.removePeerUnlocked(siteID)
 }
@@ -234,17 +235,6 @@ func (pm *PeerMonitor) Start() {
 
 // handleConnectionStatusChange is called when a peer's connection status changes
 func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
-    // Call the user-provided callback first
-    if pm.callback != nil {
-        pm.callback(siteID, status.Connected, status.RTT)
-    }
-
-    // If disconnected, send relay message to the server
-    if !status.Connected {
-        if pm.wsClient != nil {
-            pm.sendRelay(siteID)
-        }
-    }
 }
 
 // sendRelay sends a relay message to the server
@@ -264,6 +254,23 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
     return nil
 }
 
+// sendUnRelay sends an unrelay message to the server
+func (pm *PeerMonitor) sendUnRelay(siteID int) error {
+    if pm.wsClient == nil {
+        return fmt.Errorf("websocket client is nil")
+    }
+
+    err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
+        "siteId": siteID,
+    })
+    if err != nil {
+        logger.Error("Failed to send unrelay message: %v", err)
+        return err
+    }
+    logger.Info("Sent unrelay message")
+    return nil
+}
+
 // Stop stops monitoring all peers
 func (pm *PeerMonitor) Stop() {
     // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -284,11 +291,15 @@ func (pm *PeerMonitor) Stop() {
     }
 }
 
-// SetHolepunchStatusCallback sets the callback for holepunch status changes
-func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
+// MarkPeerRelayed marks a peer as currently using relay
+func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
     pm.mutex.Lock()
     defer pm.mutex.Unlock()
-    pm.holepunchStatusCallback = callback
+    pm.relayedPeers[siteID] = relayed
+    if relayed {
+        // Reset failure count when marked as relayed
+        pm.holepunchFailures[siteID] = 0
+    }
 }
 
 // startHolepunchMonitor starts the holepunch connection monitoring
@@ -358,6 +369,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
         endpoints[siteID] = endpoint
     }
     timeout := pm.holepunchTimeout
+    maxAttempts := pm.holepunchMaxAttempts
     pm.mutex.Unlock()
 
     for siteID, endpoint := range endpoints {
@@ -366,7 +378,15 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
         pm.mutex.Lock()
         previousStatus, exists := pm.holepunchStatus[siteID]
         pm.holepunchStatus[siteID] = result.Success
-        callback := pm.holepunchStatusCallback
+        isRelayed := pm.relayedPeers[siteID]
+
+        // Track consecutive failures for relay triggering
+        if result.Success {
+            pm.holepunchFailures[siteID] = 0
+        } else {
+            pm.holepunchFailures[siteID]++
+        }
+        failureCount := pm.holepunchFailures[siteID]
         pm.mutex.Unlock()
 
         // Log status changes
@@ -382,9 +402,19 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
             }
         }
 
-        // Call the callback if set
-        if callback != nil {
-            callback(siteID, endpoint, result.Success, result.RTT)
+        // Handle relay logic based on holepunch status
+        if !result.Success && !isRelayed && failureCount >= maxAttempts {
+            // Holepunch failed and we're not relayed - trigger relay
+            logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
+            if pm.wsClient != nil {
+                pm.sendRelay(siteID)
+            }
+        } else if result.Success && isRelayed {
+            // Holepunch succeeded and we ARE relayed - switch back to direct
+            logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
+            if pm.wsClient != nil {
+                pm.sendUnRelay(siteID)
+            }
         }
     }
 }
diff --git a/peers/types.go b/peers/types.go
index 49d0924..b2867b3 100644
--- a/peers/types.go
+++ b/peers/types.go
@@ -30,9 +30,13 @@ type PeerRemove struct {
 }
 
 type RelayPeerData struct {
-    SiteId    int    `json:"siteId"`
-    Endpoint  string `json:"endpoint"`
-    PublicKey string `json:"publicKey"`
+    SiteId        int    `json:"siteId"`
+    RelayEndpoint string `json:"relayEndpoint"`
+}
+
+type UnRelayPeerData struct {
+    SiteId   int    `json:"siteId"`
+    Endpoint string `json:"endpoint"`
 }
 
 // PeerAdd represents the data needed to add remote subnets to a peer
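
The net effect of the monitor changes is a small per-site state machine: count consecutive holepunch probe failures, send olm/wg/relay once the count reaches holepunchMaxAttempts, and send olm/wg/unrelay when a probe succeeds while the peer is relayed. The sketch below isolates that decision logic so it can be read and run on its own; relayDecision, newRelayDecision, and observe are hypothetical names invented for this illustration, and unlike the real monitor, where the relayed flag only changes once the server responds and MarkPeerRelayed is called, the sketch flips the flag inline to stay self-contained.

package main

import "fmt"

// relayDecision is a toy model of the bookkeeping added to PeerMonitor
// (relayedPeers, holepunchFailures, holepunchMaxAttempts collapsed into
// one struct). Names here are illustrative, not part of the olm codebase.
type relayDecision struct {
    maxAttempts int
    failures    map[int]int
    relayed     map[int]bool
}

func newRelayDecision(maxAttempts int) *relayDecision {
    return &relayDecision{
        maxAttempts: maxAttempts,
        failures:    make(map[int]int),
        relayed:     make(map[int]bool),
    }
}

// observe records one holepunch probe result for a site and returns the
// message the monitor would send: "relay" after maxAttempts consecutive
// failures while direct, "unrelay" when a probe succeeds while relayed.
func (d *relayDecision) observe(siteID int, success bool) string {
    if success {
        d.failures[siteID] = 0
        if d.relayed[siteID] {
            d.relayed[siteID] = false // in olm this happens via MarkPeerRelayed
            return "unrelay"
        }
        return ""
    }
    d.failures[siteID]++
    if !d.relayed[siteID] && d.failures[siteID] >= d.maxAttempts {
        d.relayed[siteID] = true // in olm this happens via MarkPeerRelayed
        return "relay"
    }
    return ""
}

func main() {
    d := newRelayDecision(3) // matches holepunchMaxAttempts: 3 in the diff
    for i, ok := range []bool{false, false, false, false, true} {
        if action := d.observe(42, ok); action != "" {
            fmt.Printf("probe %d: would send %s for site 42\n", i+1, action)
        }
    }
}

Running it prints the relay trigger on the third failed probe and the unrelay trigger on the first success afterwards, mirroring the default of three consecutive failures set in NewPeerMonitor.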