Handle canceling sends for relay

This commit is contained in:
Owen
2026-03-06 15:15:31 -08:00
parent 051c0fdfd8
commit c67c2a60a1
2 changed files with 81 additions and 20 deletions

View File

@@ -172,12 +172,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.RelayPeerData var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if relayData.ChainId != "" {
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint) primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil { if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err) logger.Error("Failed to resolve primary relay endpoint: %v", err)
@@ -205,12 +214,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return return
} }
var relayData peers.UnRelayPeerData var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil { if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err) logger.Error("Error unmarshaling relay data: %v", err)
return return
} }
if relayData.ChainId != "" {
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint) primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil { if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err) logger.Warn("Failed to resolve primary relay endpoint: %v", err)

View File

@@ -2,6 +2,8 @@ package monitor
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
@@ -35,6 +37,10 @@ type PeerMonitor struct {
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields // Netstack fields
middleDev *middleDevice.MiddleDevice middleDev *middleDevice.MiddleDevice
localIP string localIP string
@@ -82,6 +88,12 @@ type PeerMonitor struct {
} }
// NewPeerMonitor creates a new peer monitor with the given callback // NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
@@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string), holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool), holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool), relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int), holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds // Rapid initial test settings: complete within ~1.5 seconds
@@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
} }
} }
// sendRelay sends a relay message to the server // sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error { func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent relay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
@@ -419,23 +435,40 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID) return pm.sendRelay(siteID)
} }
// sendUnRelay sends an unrelay message to the server // sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error { func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil { if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil") return fmt.Errorf("websocket client is nil")
} }
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ chainId := generateChainId()
"siteId": siteID, stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
}) "siteId": siteID,
if err != nil { "chainId": chainId,
logger.Error("Failed to send registration message: %v", err) }, 2*time.Second, 10)
return err
} pm.relaySendMu.Lock()
logger.Info("Sent unrelay message") pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil return nil
} }
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers // Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() { func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -677,6 +710,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock) // Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor() pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()