diff --git a/olm/olm.go b/olm/olm.go
index ddc4e88..264e651 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -419,6 +419,7 @@ func StartTunnel(config TunnelConfig) {
 		config.Holepunch && !config.DisableRelay, // Enable relay only if holepunching is enabled and DisableRelay is false
 		middleDev,
 		interfaceIP,
+		sharedBind, // Pass sharedBind for holepunch testing
 	)
 
 	peerManager = peers.NewPeerManager(dev, peerMonitor, dnsProxy, interfaceName, privateKey)
@@ -432,9 +433,20 @@ func StartTunnel(config TunnelConfig) {
 			return
 		}
 
+		// Add holepunch monitoring for this endpoint if holepunching is enabled
+		if config.Holepunch {
+			peerMonitor.AddHolepunchEndpoint(site.SiteId, site.Endpoint)
+		}
+
 		logger.Info("Configured peer %s", site.PublicKey)
 	}
 
+	peerMonitor.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) {
+		// This callback is for additional handling if needed
+		// The PeerMonitor already logs status changes
+		logger.Info("Holepunch monitor callback: site %d, endpoint %s, connected=%v, rtt=%v", siteID, endpoint, connected, rtt)
+	})
+
 	peerMonitor.Start()
 
 	// Set up DNS override to use our DNS proxy
diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go
index dcdd1d9..b83f705 100644
--- a/peermonitor/peermonitor.go
+++ b/peermonitor/peermonitor.go
@@ -9,6 +9,8 @@ import (
 	"sync"
 	"time"
 
+	"github.com/fosrl/newt/bind"
+	"github.com/fosrl/newt/holepunch"
 	"github.com/fosrl/newt/logger"
 	"github.com/fosrl/newt/util"
 	middleDevice "github.com/fosrl/olm/device"
@@ -28,6 +30,9 @@ import (
 // PeerMonitorCallback is the function type for connection status change callbacks
 type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
 
+// HolepunchStatusCallback is called when holepunch connection status changes
+type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)
+
 // WireGuardConfig holds the WireGuard configuration for a peer
 type WireGuardConfig struct {
 	SiteID int
@@ -62,33 +67,53 @@ type PeerMonitor struct {
 	nsCtx    context.Context
 	nsCancel context.CancelFunc
 	nsWg     sync.WaitGroup
+
+	// Holepunch testing fields
+	sharedBind              *bind.SharedBind
+	holepunchTester         *holepunch.HolepunchTester
+	holepunchInterval       time.Duration
+	holepunchTimeout        time.Duration
+	holepunchEndpoints      map[int]string // siteID -> endpoint for holepunch testing
+	holepunchStatus         map[int]bool   // siteID -> connected status
+	holepunchStatusCallback HolepunchStatusCallback
+	holepunchStopChan       chan struct{}
 }
 
 // NewPeerMonitor creates a new peer monitor with the given callback
-func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
+func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
 	ctx, cancel := context.WithCancel(context.Background())
 
 	pm := &PeerMonitor{
-		monitors:          make(map[int]*Client),
-		configs:           make(map[int]*WireGuardConfig),
-		callback:          callback,
-		interval:          1 * time.Second, // Default check interval
-		timeout:           2500 * time.Millisecond,
-		maxAttempts:       15,
-		privateKey:        privateKey,
-		wsClient:          wsClient,
-		device:            device,
-		handleRelaySwitch: handleRelaySwitch,
-		middleDev:         middleDev,
-		localIP:           localIP,
-		activePorts:       make(map[uint16]bool),
-		nsCtx:             ctx,
-		nsCancel:          cancel,
+		monitors:           make(map[int]*Client),
+		configs:            make(map[int]*WireGuardConfig),
+		callback:           callback,
+		interval:           1 * time.Second, // Default check interval
+		timeout:            2500 * time.Millisecond,
+		maxAttempts:        15,
+		privateKey:         privateKey,
+		wsClient:           wsClient,
+		device:             device,
+		handleRelaySwitch:  handleRelaySwitch,
+		middleDev:          middleDev,
+		localIP:            localIP,
+		activePorts:        make(map[uint16]bool),
+		nsCtx:              ctx,
+		nsCancel:           cancel,
+		sharedBind:         sharedBind,
+		holepunchInterval:  5 * time.Second, // Check holepunch every 5 seconds
+		holepunchTimeout:   3 * time.Second,
+		holepunchEndpoints: make(map[int]string),
+		holepunchStatus:    make(map[int]bool),
 	}
 
 	if err := pm.initNetstack(); err != nil {
 		logger.Error("Failed to initialize netstack for peer monitor: %v", err)
 	}
 
+	// Initialize holepunch tester if sharedBind is available
+	if sharedBind != nil {
+		pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
+	}
+
 	return pm
 }
 
@@ -209,6 +234,8 @@ func (pm *PeerMonitor) Start() {
 		}
 		logger.Info("Started monitoring peer %d\n", siteID)
 	}
+
+	pm.startHolepunchMonitor()
 }
 
 // handleConnectionStatusChange is called when a peer's connection status changes
@@ -282,6 +309,9 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
 
 // Stop stops monitoring all peers
 func (pm *PeerMonitor) Stop() {
+	// Stop holepunch monitor first (outside of mutex to avoid deadlock)
+	pm.stopHolepunchMonitor()
+
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 
@@ -297,8 +327,148 @@
 	}
 }
 
+// SetHolepunchStatusCallback sets the callback for holepunch status changes
+func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	pm.holepunchStatusCallback = callback
+}
+
+// AddHolepunchEndpoint adds an endpoint to monitor via holepunch magic packets
+func (pm *PeerMonitor) AddHolepunchEndpoint(siteID int, endpoint string) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+
+	pm.holepunchEndpoints[siteID] = endpoint
+	pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
+	logger.Info("Added holepunch monitoring for site %d at %s", siteID, endpoint)
+}
+
+// RemoveHolepunchEndpoint removes an endpoint from holepunch monitoring
+func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+
+	delete(pm.holepunchEndpoints, siteID)
+	delete(pm.holepunchStatus, siteID)
+	logger.Info("Removed holepunch monitoring for site %d", siteID)
+}
+
+// startHolepunchMonitor starts the holepunch connection monitoring
+// Note: This function assumes the mutex is already held by the caller (called from Start())
+func (pm *PeerMonitor) startHolepunchMonitor() error {
+	if pm.holepunchTester == nil {
+		return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)")
+	}
+
+	if pm.holepunchStopChan != nil {
+		return fmt.Errorf("holepunch monitor already running")
+	}
+
+	if err := pm.holepunchTester.Start(); err != nil {
+		return fmt.Errorf("failed to start holepunch tester: %w", err)
+	}
+
+	pm.holepunchStopChan = make(chan struct{})
+
+	go pm.runHolepunchMonitor(pm.holepunchStopChan)
+
+	logger.Info("Started holepunch connection monitor")
+	return nil
+}
+
+// stopHolepunchMonitor stops the holepunch connection monitoring
+func (pm *PeerMonitor) stopHolepunchMonitor() {
+	pm.mutex.Lock()
+	stopChan := pm.holepunchStopChan
+	pm.holepunchStopChan = nil
+	pm.mutex.Unlock()
+
+	if stopChan != nil {
+		close(stopChan)
+	}
+
+	if pm.holepunchTester != nil {
+		pm.holepunchTester.Stop()
+	}
+
+	logger.Info("Stopped holepunch connection monitor")
+}
+
+// runHolepunchMonitor runs the holepunch monitoring loop until stopChan is closed (the channel is passed in so the loop never races with stopHolepunchMonitor resetting pm.holepunchStopChan)
+func (pm *PeerMonitor) runHolepunchMonitor(stopChan chan struct{}) {
+	ticker := time.NewTicker(pm.holepunchInterval)
+	defer ticker.Stop()
+
+	// Do initial check immediately
+	pm.checkHolepunchEndpoints()
+
+	for {
+		select {
+		case <-stopChan:
+			return
+		case <-ticker.C:
+			pm.checkHolepunchEndpoints()
+		}
+	}
+}
+
+// checkHolepunchEndpoints tests all holepunch endpoints
+func (pm *PeerMonitor) checkHolepunchEndpoints() {
+	pm.mutex.Lock()
+	endpoints := make(map[int]string, len(pm.holepunchEndpoints))
+	for siteID, endpoint := range pm.holepunchEndpoints {
+		endpoints[siteID] = endpoint
+	}
+	timeout := pm.holepunchTimeout
+	pm.mutex.Unlock()
+
+	for siteID, endpoint := range endpoints {
+		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
+
+		pm.mutex.Lock()
+		previousStatus, exists := pm.holepunchStatus[siteID]
+		pm.holepunchStatus[siteID] = result.Success
+		callback := pm.holepunchStatusCallback
+		pm.mutex.Unlock()
+
+		// Log status changes
+		if !exists || previousStatus != result.Success {
+			if result.Success {
+				logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
+			} else {
+				if result.Error != nil {
+					logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error)
+				} else {
+					logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint)
+				}
+			}
+		}
+
+		// Call the callback if set
+		if callback != nil {
+			callback(siteID, endpoint, result.Success, result.RTT)
+		}
+	}
+}
+
+// GetHolepunchStatus returns the current holepunch status for all endpoints
+func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+
+	status := make(map[int]bool, len(pm.holepunchStatus))
+	for siteID, connected := range pm.holepunchStatus {
+		status[siteID] = connected
+	}
+	return status
+}
+
 // Close stops monitoring and cleans up resources
 func (pm *PeerMonitor) Close() {
+	// Stop holepunch monitor first (outside of mutex to avoid deadlock)
+	pm.stopHolepunchMonitor()
+
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
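For reference, here is a minimal usage sketch of the monitoring API this diff adds to PeerMonitor. It is illustrative only: the import path github.com/fosrl/olm/peermonitor, the site ID, and the endpoint address are assumptions not taken from the diff, and pm is assumed to have been built by NewPeerMonitor with a non-nil *bind.SharedBind so the holepunch tester is initialized.

package example

import (
	"fmt"
	"time"

	"github.com/fosrl/olm/peermonitor" // assumed import path; not shown in the diff
)

// monitorHolepunch sketches how a caller might wire up the holepunch
// monitoring API added in this diff. Site ID 42 and the endpoint address
// are placeholders.
func monitorHolepunch(pm *peermonitor.PeerMonitor) {
	// Register the callback before Start so the initial probe is reported.
	pm.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) {
		fmt.Printf("site %d (%s): connected=%v rtt=%v\n", siteID, endpoint, connected, rtt)
	})

	// Endpoints added here are probed with holepunch magic packets every
	// holepunchInterval (5s by default) with a 3s per-probe timeout.
	pm.AddHolepunchEndpoint(42, "203.0.113.10:51820")

	// Start launches peer monitoring and, when a SharedBind was provided,
	// the holepunch monitoring loop.
	pm.Start()

	// Later, take a snapshot of per-site holepunch status.
	fmt.Printf("holepunch status: %v\n", pm.GetHolepunchStatus())

	// Clean up: drop the endpoint and stop all monitoring.
	pm.RemoveHolepunchEndpoint(42)
	pm.Stop()
}

Note that, as written in the diff, the status callback fires after every probe rather than only on state transitions, so consumers that only care about changes need to de-duplicate on their side.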