diff --git a/olm/peer.go b/olm/peer.go index 1937934..c611921 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -172,12 +172,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) { return } - var relayData peers.RelayPeerData + var relayData struct { + peers.RelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if relayData.ChainId != "" { + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + } + primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) if err != nil { logger.Error("Failed to resolve primary relay endpoint: %v", err) @@ -205,12 +214,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { return } - var relayData peers.UnRelayPeerData + var relayData struct { + peers.UnRelayPeerData + ChainId string `json:"chainId"` + } if err := json.Unmarshal(jsonData, &relayData); err != nil { logger.Error("Error unmarshaling relay data: %v", err) return } + if relayData.ChainId != "" { + if monitor := o.peerManager.GetPeerMonitor(); monitor != nil { + monitor.CancelRelaySend(relayData.ChainId) + } + } + primaryRelay, err := util.ResolveDomain(relayData.Endpoint) if err != nil { logger.Warn("Failed to resolve primary relay endpoint: %v", err) diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 28d92ef..1296fef 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -2,6 +2,8 @@ package monitor import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/netip" @@ -35,6 +37,10 @@ type PeerMonitor struct { maxAttempts int wsClient *websocket.Client + // Relay sender tracking + relaySends map[string]func() + relaySendMu sync.Mutex + // Netstack fields middleDev *middleDevice.MiddleDevice localIP string @@ -82,6 +88,12 @@ type PeerMonitor struct { } // NewPeerMonitor creates a new peer monitor with the given callback +func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ @@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), relayedPeers: make(map[int]bool), + relaySends: make(map[string]func()), holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchFailures: make(map[int]int), // Rapid initial test settings: complete within ~1.5 seconds @@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio } } -// sendRelay sends a relay message to the server +// sendRelay sends a relay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent relay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId) return nil } @@ -419,23 +435,40 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error { return pm.sendRelay(siteID) } -// sendUnRelay sends an unrelay message to the server +// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId func (pm *PeerMonitor) sendUnRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } - err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ - "siteId": siteID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent unrelay message") + chainId := generateChainId() + stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{ + "siteId": siteID, + "chainId": chainId, + }, 2*time.Second, 10) + + pm.relaySendMu.Lock() + pm.relaySends[chainId] = stopFunc + pm.relaySendMu.Unlock() + + logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId) return nil } +// CancelRelaySend stops the interval sender for the given chainId, if one exists. +func (pm *PeerMonitor) CancelRelaySend(chainId string) { + pm.relaySendMu.Lock() + defer pm.relaySendMu.Unlock() + + if stop, ok := pm.relaySends[chainId]; ok { + stop() + delete(pm.relaySends, chainId) + logger.Info("Cancelled relay sender for chain %s", chainId) + } else { + logger.Warn("CancelRelaySend: no active sender for chain %s", chainId) + } +} + // Stop stops monitoring all peers func (pm *PeerMonitor) Stop() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) @@ -677,6 +710,16 @@ func (pm *PeerMonitor) Close() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) pm.stopHolepunchMonitor() + // Stop all pending relay senders + pm.relaySendMu.Lock() + for chainId, stop := range pm.relaySends { + if stop != nil { + stop() + } + delete(pm.relaySends, chainId) + } + pm.relaySendMu.Unlock() + pm.mutex.Lock() defer pm.mutex.Unlock()