Mirror of https://github.com/fosrl/olm.git
Clean up and add unrelay
olm/olm.go (30 changed lines)
@@ -686,15 +686,41 @@ func StartTunnel(config TunnelConfig) {
			return
		}

		primaryRelay, err := util.ResolveDomain(relayData.RelayEndpoint)
		if err != nil {
			logger.Warn("Failed to resolve primary relay endpoint: %v", err)
		}

		// Update HTTP server to mark this peer as using relay
		apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.RelayEndpoint, true)

		peerManager.RelayPeer(relayData.SiteId, primaryRelay)
	})

	olm.RegisterHandler("olm/wg/peer/unrelay", func(msg websocket.WSMessage) {
		logger.Debug("Received unrelay-peer message: %v", msg.Data)

		jsonData, err := json.Marshal(msg.Data)
		if err != nil {
			logger.Error("Error marshaling data: %v", err)
			return
		}

		var relayData peers.UnRelayPeerData
		if err := json.Unmarshal(jsonData, &relayData); err != nil {
			logger.Error("Error unmarshaling relay data: %v", err)
			return
		}

		primaryRelay, err := util.ResolveDomain(relayData.Endpoint)
		if err != nil {
			logger.Warn("Failed to resolve primary relay endpoint: %v", err)
		}

		// Update HTTP server to mark this peer as using relay
		apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, true)
		apiServer.UpdatePeerRelayStatus(relayData.SiteId, relayData.Endpoint, false)

		peerManager.HandleFailover(relayData.SiteId, primaryRelay)
		peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
	})

	// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
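Both handlers above decode msg.Data with the same marshal-then-unmarshal round trip. Purely for illustration, that pattern could be factored into a small helper like the hypothetical decodeInto below; it is not part of this commit and assumes only the encoding/json import the handlers already use:

// decodeInto re-marshals a loosely typed websocket payload (e.g. a map) into a
// concrete struct such as peers.RelayPeerData or peers.UnRelayPeerData.
// Hypothetical helper, shown only to make the handlers' decode step explicit.
func decodeInto(data interface{}, out interface{}) error {
	b, err := json.Marshal(data)
	if err != nil {
		return err
	}
	return json.Unmarshal(b, out)
}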
peers/manager.go (109 changed lines)
@@ -6,11 +6,11 @@ import (
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/fosrl/newt/bind"
	"github.com/fosrl/newt/logger"
	"github.com/fosrl/newt/network"
	"github.com/fosrl/newt/util"
	"github.com/fosrl/olm/api"
	olmDevice "github.com/fosrl/olm/device"
	"github.com/fosrl/olm/dns"
@@ -20,10 +20,6 @@ import (
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

// HolepunchStatusCallback is called when holepunch connection status changes
// This is an alias for monitor.HolepunchStatusCallback
type HolepunchStatusCallback = monitor.HolepunchStatusCallback

// PeerManagerConfig contains the configuration for creating a PeerManager
type PeerManagerConfig struct {
	Device *device.Device
@@ -71,34 +67,6 @@ func NewPeerManager(config PeerManagerConfig) *PeerManager {

	// Create the peer monitor
	pm.peerMonitor = monitor.NewPeerMonitor(
		func(siteID int, connected bool, rtt time.Duration) {
			// Update API status directly
			if pm.APIServer != nil {
				// Find the peer config to get endpoint information
				pm.mu.RLock()
				peer, exists := pm.peers[siteID]
				pm.mu.RUnlock()

				var endpoint string
				var isRelay bool
				if exists {
					if peer.RelayEndpoint != "" {
						endpoint = peer.RelayEndpoint
						isRelay = true
					} else {
						endpoint = peer.Endpoint
						isRelay = false
					}
				}
				pm.APIServer.UpdatePeerStatus(siteID, connected, rtt, endpoint, isRelay)
			}

			if connected {
				logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
			} else {
				logger.Warn("Peer %d is disconnected", siteID)
			}
		},
		config.WSClient,
		config.MiddleDev,
		config.LocalIP,
@@ -677,11 +645,16 @@ func (pm *PeerManager) RemoveAlias(siteId int, aliasName string) error {
	return nil
}

// HandleFailover handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
	pm.mu.RLock()
// RelayPeer handles failover to the relay server when a peer is disconnected
func (pm *PeerManager) RelayPeer(siteId int, relayEndpoint string) {
	pm.mu.Lock()
	peer, exists := pm.peers[siteId]
	pm.mu.RUnlock()
	if exists {
		// Store the relay endpoint
		peer.RelayEndpoint = relayEndpoint
		pm.peers[siteId] = peer
	}
	pm.mu.Unlock()

	if !exists {
		logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
@@ -697,7 +670,7 @@ func (pm *PeerManager) HandleFailover(siteId int, relayEndpoint string) {
	// Update only the endpoint for this peer (update_only preserves other settings)
	wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)

	err := pm.device.IpcSet(wgConfig)
	if err != nil {
@@ -705,6 +678,11 @@ endpoint=%s:21820`, peer.PublicKey, formattedEndpoint)
		return
	}

	// Mark the peer as relayed in the monitor
	if pm.peerMonitor != nil {
		pm.peerMonitor.MarkPeerRelayed(siteId, true)
	}

	logger.Info("Adjusted peer %d to point to relay!\n", siteId)
}
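For reference, the wgConfig string built in this function is WireGuard UAPI text. With a placeholder peer key and a relay that resolves to, say, 203.0.113.10, IpcSet would receive roughly the following (the exact key encoding depends on util.FixKey, which presumably puts the key into the form UAPI expects):

public_key=<peer public key as emitted by util.FixKey>
update_only=true
endpoint=203.0.113.10:21820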
@@ -730,9 +708,58 @@ func (pm *PeerManager) Close() {
	}
}

// SetHolepunchStatusCallback sets the callback for holepunch status changes
func (pm *PeerManager) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerManager) MarkPeerRelayed(siteID int, relayed bool) {
	pm.mu.Lock()
	if peer, exists := pm.peers[siteID]; exists {
		if relayed {
			// We're being relayed, store the current endpoint as the original
			// (RelayEndpoint is set by HandleFailover)
		} else {
			// Clear relay endpoint when switching back to direct
			peer.RelayEndpoint = ""
			pm.peers[siteID] = peer
		}
	}
	pm.mu.Unlock()

	if pm.peerMonitor != nil {
		pm.peerMonitor.SetHolepunchStatusCallback(callback)
		pm.peerMonitor.MarkPeerRelayed(siteID, relayed)
	}
}

// UnRelayPeer switches a peer from relay back to direct connection
func (pm *PeerManager) UnRelayPeer(siteId int, endpoint string) error {
	pm.mu.Lock()
	peer, exists := pm.peers[siteId]
	if exists {
		// Store the relay endpoint
		peer.Endpoint = endpoint
		pm.peers[siteId] = peer
	}
	pm.mu.Unlock()

	if !exists {
		logger.Error("Cannot handle failover: peer with site ID %d not found", siteId)
		return nil
	}

	// Update WireGuard to use the direct endpoint
	wgConfig := fmt.Sprintf(`public_key=%s
update_only=true
endpoint=%s`, util.FixKey(peer.PublicKey), endpoint)

	err := pm.device.IpcSet(wgConfig)
	if err != nil {
		logger.Error("Failed to switch peer %d to direct connection: %v", siteId, err)
		return err
	}

	// Mark as not relayed in monitor
	if pm.peerMonitor != nil {
		pm.peerMonitor.MarkPeerRelayed(siteId, false)
	}

	logger.Info("Switched peer %d back to direct connection at %s", siteId, endpoint)
	return nil
}
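One detail worth noting in this pair: RelayPeer always appends the fixed relay port (:21820) to the resolved relay host, while UnRelayPeer writes its endpoint argument into the UAPI config verbatim, so callers are expected to pass a full host:port for the direct address.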
@@ -25,16 +25,9 @@ import (
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)

// PeerMonitorCallback is the function type for connection status change callbacks
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)

// HolepunchStatusCallback is called when holepunch connection status changes
type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)

// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
type PeerMonitor struct {
	monitors map[int]*Client
	callback PeerMonitorCallback
	mutex    sync.Mutex
	running  bool
	interval time.Duration
@@ -54,36 +47,42 @@ type PeerMonitor struct {
	nsWg sync.WaitGroup

	// Holepunch testing fields
	sharedBind              *bind.SharedBind
	holepunchTester         *holepunch.HolepunchTester
	holepunchInterval       time.Duration
	holepunchTimeout        time.Duration
	holepunchEndpoints      map[int]string // siteID -> endpoint for holepunch testing
	holepunchStatus         map[int]bool   // siteID -> connected status
	holepunchStatusCallback HolepunchStatusCallback
	holepunchStopChan       chan struct{}
	sharedBind         *bind.SharedBind
	holepunchTester    *holepunch.HolepunchTester
	holepunchInterval  time.Duration
	holepunchTimeout   time.Duration
	holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
	holepunchStatus    map[int]bool   // siteID -> connected status
	holepunchStopChan  chan struct{}

	// Relay tracking fields
	relayedPeers         map[int]bool // siteID -> whether the peer is currently relayed
	holepunchMaxAttempts int          // max consecutive failures before triggering relay
	holepunchFailures    map[int]int  // siteID -> consecutive failure count
}

// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
	ctx, cancel := context.WithCancel(context.Background())
	pm := &PeerMonitor{
		monitors:           make(map[int]*Client),
		callback:           callback,
		interval:           1 * time.Second, // Default check interval
		timeout:            2500 * time.Millisecond,
		maxAttempts:        15,
		wsClient:           wsClient,
		middleDev:          middleDev,
		localIP:            localIP,
		activePorts:        make(map[uint16]bool),
		nsCtx:              ctx,
		nsCancel:           cancel,
		sharedBind:         sharedBind,
		holepunchInterval:  5 * time.Second, // Check holepunch every 5 seconds
		holepunchTimeout:   3 * time.Second,
		holepunchEndpoints: make(map[int]string),
		holepunchStatus:    make(map[int]bool),
		monitors:             make(map[int]*Client),
		interval:             1 * time.Second, // Default check interval
		timeout:              2500 * time.Millisecond,
		maxAttempts:          15,
		wsClient:             wsClient,
		middleDev:            middleDev,
		localIP:              localIP,
		activePorts:          make(map[uint16]bool),
		nsCtx:                ctx,
		nsCancel:             cancel,
		sharedBind:           sharedBind,
		holepunchInterval:    5 * time.Second, // Check holepunch every 5 seconds
		holepunchTimeout:     3 * time.Second,
		holepunchEndpoints:   make(map[int]string),
		holepunchStatus:      make(map[int]bool),
		relayedPeers:         make(map[int]bool),
		holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
		holepunchFailures:    make(map[int]int),
	}

	if err := pm.initNetstack(); err != nil {
@@ -201,6 +200,8 @@ func (pm *PeerMonitor) RemovePeer(siteID int) {
	// remove the holepunch endpoint info
	delete(pm.holepunchEndpoints, siteID)
	delete(pm.holepunchStatus, siteID)
	delete(pm.relayedPeers, siteID)
	delete(pm.holepunchFailures, siteID)

	pm.removePeerUnlocked(siteID)
}
@@ -234,17 +235,6 @@ func (pm *PeerMonitor) Start() {

// handleConnectionStatusChange is called when a peer's connection status changes
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
	// Call the user-provided callback first
	if pm.callback != nil {
		pm.callback(siteID, status.Connected, status.RTT)
	}

	// If disconnected, send relay message to the server
	if !status.Connected {
		if pm.wsClient != nil {
			pm.sendRelay(siteID)
		}
	}
}

// sendRelay sends a relay message to the server
@@ -264,6 +254,23 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
	return nil
}

// sendRelay sends a relay message to the server
func (pm *PeerMonitor) sendUnRelay(siteID int) error {
	if pm.wsClient == nil {
		return fmt.Errorf("websocket client is nil")
	}

	err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{
		"siteId": siteID,
	})
	if err != nil {
		logger.Error("Failed to send registration message: %v", err)
		return err
	}
	logger.Info("Sent unrelay message")
	return nil
}

// Stop stops monitoring all peers
func (pm *PeerMonitor) Stop() {
	// Stop holepunch monitor first (outside of mutex to avoid deadlock)
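Assuming SendMessage serializes its map argument to JSON (an assumption about the newt websocket client, not something this diff shows), the unrelay notification from sendUnRelay above reduces to a minimal payload such as {"siteId": 7} published under the olm/wg/unrelay topic; the server is then presumably expected to reply with the olm/wg/peer/unrelay message handled in olm/olm.go earlier in this diff.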
@@ -284,11 +291,15 @@ func (pm *PeerMonitor) Stop() {
	}
}

// SetHolepunchStatusCallback sets the callback for holepunch status changes
func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
// MarkPeerRelayed marks a peer as currently using relay
func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()
	pm.holepunchStatusCallback = callback
	pm.relayedPeers[siteID] = relayed
	if relayed {
		// Reset failure count when marked as relayed
		pm.holepunchFailures[siteID] = 0
	}
}

// startHolepunchMonitor starts the holepunch connection monitoring
@@ -358,6 +369,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
		endpoints[siteID] = endpoint
	}
	timeout := pm.holepunchTimeout
	maxAttempts := pm.holepunchMaxAttempts
	pm.mutex.Unlock()

	for siteID, endpoint := range endpoints {
@@ -366,7 +378,15 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
		pm.mutex.Lock()
		previousStatus, exists := pm.holepunchStatus[siteID]
		pm.holepunchStatus[siteID] = result.Success
		callback := pm.holepunchStatusCallback
		isRelayed := pm.relayedPeers[siteID]

		// Track consecutive failures for relay triggering
		if result.Success {
			pm.holepunchFailures[siteID] = 0
		} else {
			pm.holepunchFailures[siteID]++
		}
		failureCount := pm.holepunchFailures[siteID]
		pm.mutex.Unlock()

		// Log status changes
@@ -382,9 +402,19 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
			}
		}

		// Call the callback if set
		if callback != nil {
			callback(siteID, endpoint, result.Success, result.RTT)
		// Handle relay logic based on holepunch status
		if !result.Success && !isRelayed && failureCount >= maxAttempts {
			// Holepunch failed and we're not relayed - trigger relay
			logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount)
			if pm.wsClient != nil {
				pm.sendRelay(siteID)
			}
		} else if result.Success && isRelayed {
			// Holepunch succeeded and we ARE relayed - switch back to direct
			logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID)
			if pm.wsClient != nil {
				pm.sendUnRelay(siteID)
			}
		}
	}
}
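With the defaults set in NewPeerMonitor (holepunchInterval of 5 seconds, holepunchMaxAttempts of 3), a peer whose direct path stops answering asks the server for a relay roughly 15 seconds after probes start failing, and a single successful probe while relayed requests the switch back to direct on the next 5-second tick.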
@@ -30,9 +30,13 @@ type PeerRemove struct {
}

type RelayPeerData struct {
	SiteId    int    `json:"siteId"`
	Endpoint  string `json:"endpoint"`
	PublicKey string `json:"publicKey"`
	SiteId        int    `json:"siteId"`
	RelayEndpoint string `json:"relayEndpoint"`
}

type UnRelayPeerData struct {
	SiteId   int    `json:"siteId"`
	Endpoint string `json:"endpoint"`
}

// PeerAdd represents the data needed to add remote subnets to a peer
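For illustration, payloads matching these JSON tags would look like the following (values are invented; the topic-to-struct mapping comes from the handlers in olm/olm.go earlier in this diff):

olm/wg/peer/relay   -> {"siteId": 7, "relayEndpoint": "relay.example.com"}
olm/wg/peer/unrelay -> {"siteId": 7, "endpoint": "198.51.100.23:51820"}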