Handle canceling sends for relay

This commit is contained in:
Owen
2026-03-06 15:15:31 -08:00
parent 051c0fdfd8
commit c67c2a60a1
2 changed files with 81 additions and 20 deletions

View File

@@ -172,12 +172,21 @@ func (o *Olm) handleWgPeerRelay(msg websocket.WSMessage) {
return
}
var relayData peers.RelayPeerData
var relayData struct {
peers.RelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if relayData.ChainId != "" {
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Error("Failed to resolve primary relay endpoint: %v", err)
@@ -205,12 +214,21 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
return
}
var relayData peers.UnRelayPeerData
var relayData struct {
peers.UnRelayPeerData
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
if relayData.ChainId != "" {
if monitor := o.peerManager.GetPeerMonitor(); monitor != nil {
monitor.CancelRelaySend(relayData.ChainId)
}
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)

View File

@@ -2,6 +2,8 @@ package monitor
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"net"
"net/netip"
@@ -35,6 +37,10 @@ type PeerMonitor struct {
maxAttempts int
wsClient *websocket.Client
// Relay sender tracking
relaySends map[string]func()
relaySendMu sync.Mutex
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
@@ -82,6 +88,12 @@ type PeerMonitor struct {
}
// NewPeerMonitor creates a new peer monitor with the given callback
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
@@ -99,6 +111,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
relaySends: make(map[string]func()),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
// Rapid initial test settings: complete within ~1.5 seconds
@@ -396,20 +409,23 @@ func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status Connectio
}
}
// sendRelay sends a relay message to the server
// sendRelay sends a relay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent relay message")
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/relay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent relay message for site %d (chain %s)", siteID, chainId)
return nil
}
@@ -419,23 +435,40 @@ func (pm *PeerMonitor) RequestRelay(siteID int) error {
return pm.sendRelay(siteID)
}
// sendUnRelay sends an unrelay message to the server
// sendUnRelay sends an unrelay message to the server with retry, keyed by chainId
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
chainId := generateChainId()
stopFunc, _ := pm.wsClient.SendMessageInterval("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
"chainId": chainId,
}, 2*time.Second, 10)
pm.relaySendMu.Lock()
pm.relaySends[chainId] = stopFunc
pm.relaySendMu.Unlock()
logger.Info("Sent unrelay message for site %d (chain %s)", siteID, chainId)
return nil
}
// CancelRelaySend stops the interval sender for the given chainId, if one exists.
func (pm *PeerMonitor) CancelRelaySend(chainId string) {
pm.relaySendMu.Lock()
defer pm.relaySendMu.Unlock()
if stop, ok := pm.relaySends[chainId]; ok {
stop()
delete(pm.relaySends, chainId)
logger.Info("Cancelled relay sender for chain %s", chainId)
} else {
logger.Warn("CancelRelaySend: no active sender for chain %s", chainId)
}
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -677,6 +710,16 @@ func (pm *PeerMonitor) Close() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
pm.stopHolepunchMonitor()
// Stop all pending relay senders
pm.relaySendMu.Lock()
for chainId, stop := range pm.relaySends {
if stop != nil {
stop()
}
delete(pm.relaySends, chainId)
}
pm.relaySendMu.Unlock()
pm.mutex.Lock()
defer pm.mutex.Unlock()