Clean up and add unrelay

Former-commit-id: 01586510f3
This commit is contained in:
Owen
2025-12-02 10:45:30 -05:00
parent 51162d6be6
commit 2106734aa4
4 changed files with 183 additions and 96 deletions

View File

@@ -686,15 +686,41 @@ func StartTunnel(config TunnelConfig) {
return
}
primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)
peerManager.RelayPeer(relayData.SiteId, primaryRelay)
})
olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
logger.Debug("Received unrelay-peer message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling data: %v", err)
return
}
var relayData peers.UnRelayPeerData
if err := json.Unmarshal(jsonData, &relayData); err != nil {
logger.Error("Error unmarshaling relay data: %v", err)
return
}
primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
if err != nil {
logger.Warn("Failed to resolve primary relay endpoint: %v", err)
}
// Update HTTP server to mark this peer as using relay
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)
peerManager.HandleFailover(relayData.SiteId, primaryRelay)
peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
})
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server

View File

@@ -6,11 +6,11 @@ import (
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/newt/bind"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/api"
olmDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
@@ -20,10 +20,6 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// HolepunchStatusCallback is called when holepunch connection status changes
// This is an alias for monitor.HolepunchStatusCallback
type HolepunchStatusCallback = monitor.HolepunchStatusCallback
// PeerManagerConfig contains the configuration for creating a PeerManager
type PeerManagerConfig struct {
Device *device.Device
@@ -71,34 +67,6 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {
// Create the peer monitor
pm.peerMonitor = monitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Update API status directly
if pm.APIServer != nil {
// Find the peer config to get endpoint information
pm.mu.RLock()
peer, exists := pm.peers[siteID]
pm.mu.RUnlock()
var endpoint string
var isRelay bool
if exists {
if peer.RelayEndpoint != "" {
endpoint = peer.RelayEndpoint
isRelay = true
} else {
endpoint = peer.Endpoint
isRelay = false
}
}
pm.APIServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
}
if connected {
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
} else {
logger.Warn("Peer %d is disconnected", siteID)
}
},
config.WSClient,
config.MiddleDev,
config.LocalIP,
@@ -677,11 +645,16 @@ func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
return nil
}
// HandleFailover handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
pm.mu.RLock()
// RelayPeer handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
pm.mu.Lock()
peer, exists := pm.peers[siteId]
pm.mu.RUnlock()
if exists {
// Store the relay endpoint
peer.RelayEndpoint = relayEndpoint
pm.peers[siteId] = peer
}
pm.mu.Unlock()
if !exists {
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
@@ -697,7 +670,7 @@ func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
// Update only the endpoint for this peer (update_only preserves other settings)
wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
@@ -705,6 +678,11 @@ endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
return
}
// Mark the peer as relayed in the monitor
if pm.peerMonitor != nil {
pm.peerMonitor.MarkPeerRelayed(siteId, true)
}
logger.Info("Adjusted peer %d to point to relay!\n", siteId)
}
@@ -730,9 +708,58 @@ func (pm *PeerManager) Close() {
}
}
// SetHolepunchStatusCallback sets the callback for holepunch status changes
func (pm *PeerManager) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
pm.mu.Lock()
if peer, exists := pm.peers[siteID]; exists {
if relayed {
// We're being relayed, store the current endpoint as the original
// (RelayEndpoint is set by HandleFailover)
} else {
// Clear relay endpoint when switching back to direct
peer.RelayEndpoint = ""
pm.peers[siteID] = peer
}
}
pm.mu.Unlock()
if pm.peerMonitor != nil {
pm.peerMonitor.SetHolepunchStatusCallback(callback)
pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
}
}
// UnRelayPeer switches a peer from relay back to direct connection
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
pm.mu.Lock()
peer, exists := pm.peers[siteId]
if exists {
// Store the relay endpoint
peer.Endpoint = endpoint
pm.peers[siteId] = peer
}
pm.mu.Unlock()
if !exists {
logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
return nil
}
// Update WireGuard to use the direct endpoint
wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)
err := pm.device.IpcSet(wgConfig)
if err != nil {
logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
return err
}
// Mark as not relayed in monitor
if pm.peerMonitor != nil {
pm.peerMonitor.MarkPeerRelayed(siteId, false)
}
logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
return nil
}

View File

@@ -25,16 +25,9 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PeerMonitorCallback is the function type for connection status change callbacks
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
// HolepunchStatusCallback is called when holepunch connection status changes
type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
monitors map[int]*Client
callback PeerMonitorCallback
mutex sync.Mutex
running bool
interval time.Duration
@@ -54,36 +47,42 @@ type PeerMonitor struct {
nsWg sync.WaitGroup
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStatusCallback HolepunchStatusCallback
holepunchStopChan chan struct{}
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
holepunchMaxAttempts int // max consecutive failures before triggering relay
holepunchFailures map[int]int // siteID -> consecutive failure count
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
callback: callback,
interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond,
maxAttempts: 15,
wsClient: wsClient,
middleDev: middleDev,
localIP: localIP,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 5 * time.Second, // Check holepunch every 5 seconds
holepunchTimeout: 3 * time.Second,
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
monitors: make(map[int]*Client),
interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond,
maxAttempts: 15,
wsClient: wsClient,
middleDev: middleDev,
localIP: localIP,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 5 * time.Second, // Check holepunch every 5 seconds
holepunchTimeout: 3 * time.Second,
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
relayedPeers: make(map[int]bool),
holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
holepunchFailures: make(map[int]int),
}
if err := pm.initNetstack(); err != nil {
@@ -201,6 +200,8 @@ func (pm *PeerMonitor) RemovePeer(siteID int) {
// remove the holepunch endpoint info
delete(pm.holepunchEndpoints, siteID)
delete(pm.holepunchStatus, siteID)
delete(pm.relayedPeers, siteID)
delete(pm.holepunchFailures, siteID)
pm.removePeerUnlocked(siteID)
}
@@ -234,17 +235,6 @@ func (pm *PeerMonitor) Start() {
// handleConnectionStatusChange is called when a peer's connection status changes
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
// Call the user-provided callback first
if pm.callback != nil {
pm.callback(siteID, status.Connected, status.RTT)
}
// If disconnected, send relay message to the server
if !status.Connected {
if pm.wsClient != nil {
pm.sendRelay(siteID)
}
}
}
// sendRelay sends a relay message to the server
@@ -264,6 +254,23 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
return nil
}
// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
if pm.wsClient == nil {
return fmt.Errorf("websocket client is nil")
}
err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
"siteId": siteID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent unrelay message")
return nil
}
// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
@@ -284,11 +291,15 @@ func (pm *PeerMonitor) Stop() {
}
}
// SetHolepunchStatusCallback sets the callback for holepunch status changes
func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.holepunchStatusCallback = callback
pm.relayedPeers[siteID] = relayed
if relayed {
// Reset failure count when marked as relayed
pm.holepunchFailures[siteID] = 0
}
}
// startHolepunchMonitor starts the holepunch connection monitoring
@@ -358,6 +369,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
endpoints[siteID] = endpoint
}
timeout := pm.holepunchTimeout
maxAttempts := pm.holepunchMaxAttempts
pm.mutex.Unlock()
for siteID, endpoint := range endpoints {
@@ -366,7 +378,15 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
pm.mutex.Lock()
previousStatus, exists := pm.holepunchStatus[siteID]
pm.holepunchStatus[siteID] = result.Success
callback := pm.holepunchStatusCallback
isRelayed := pm.relayedPeers[siteID]
// Track consecutive failures for relay triggering
if result.Success {
pm.holepunchFailures[siteID] = 0
} else {
pm.holepunchFailures[siteID]++
}
failureCount := pm.holepunchFailures[siteID]
pm.mutex.Unlock()
// Log status changes
@@ -382,9 +402,19 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
}
}
// Call the callback if set
if callback != nil {
callback(siteID, endpoint, result.Success, result.RTT)
// Handle relay logic based on holepunch status
if !result.Success && !isRelayed && failureCount >= maxAttempts {
// Holepunch failed and we're not relayed - trigger relay
logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
if pm.wsClient != nil {
pm.sendRelay(siteID)
}
} else if result.Success && isRelayed {
// Holepunch succeeded and we ARE relayed - switch back to direct
logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
if pm.wsClient != nil {
pm.sendUnRelay(siteID)
}
}
}
}

View File

@@ -30,9 +30,13 @@ type PeerRemove struct {
}
type RelayPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
SiteId int `json:"siteId"`
RelayEndpoint string `json:"relayEndpoint"`
}
type UnRelayPeerData struct {
SiteId int `json:"siteId"`
Endpoint string `json:"endpoint"`
}
// PeerAdd represents the data needed to add remote subnets to a peer