diff --git a/peers/manager.go b/peers/manager.go
index f8d468d..78681e1 100644
--- a/peers/manager.go
+++ b/peers/manager.go
@@ -149,6 +149,11 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
 	}
 
 	pm.peers[siteConfig.SiteId] = siteConfig
+
+	// Kick off a rapid initial holepunch test in a goroutine (outside the lock)
+	// so AddPeer does not block; the test triggers relay if holepunch is not viable
+	go pm.performRapidInitialTest(siteConfig.SiteId, siteConfig.Endpoint)
+
 	return nil
 }
 
@@ -708,6 +713,28 @@ endpoint=%s:21820`, util.FixKey(peer.PublicKey), formattedEndpoint)
 	logger.Info("Adjusted peer %d to point to relay!\n", siteId)
 }
 
+// performRapidInitialTest performs a rapid holepunch test for a newly added peer.
+// If the test fails, it immediately requests relay to minimize connection delay.
+// This runs in a goroutine to avoid blocking AddPeer.
+func (pm *PeerManager) performRapidInitialTest(siteId int, endpoint string) {
+	if pm.peerMonitor == nil {
+		return
+	}
+
+	// Perform rapid test - worst case ~2.8s with the default settings
+	holepunchViable := pm.peerMonitor.RapidTestPeer(siteId, endpoint)
+
+	if !holepunchViable {
+		// Holepunch failed the rapid test; request relay immediately
+		logger.Info("Rapid test failed for site %d, requesting relay", siteId)
+		if err := pm.peerMonitor.RequestRelay(siteId); err != nil {
+			logger.Error("Failed to request relay for site %d: %v", siteId, err)
+		}
+	} else {
+		logger.Info("Rapid test passed for site %d, using direct connection", siteId)
+	}
+}
+
 // Start starts the peer monitor
 func (pm *PeerManager) Start() {
 	if pm.peerMonitor != nil {
diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go
index 215ca72..ac91cb3 100644
--- a/peers/monitor/monitor.go
+++ b/peers/monitor/monitor.go
@@ -61,6 +61,11 @@ type PeerMonitor struct {
 	holepunchMaxAttempts int         // max consecutive failures before triggering relay
 	holepunchFailures    map[int]int // siteID -> consecutive failure count
 
+	// Rapid initial test fields
+	rapidTestInterval    time.Duration // interval between rapid test attempts
+	rapidTestTimeout     time.Duration // timeout for each rapid test attempt
+	rapidTestMaxAttempts int           // max attempts during rapid test phase
+
 	// API server for status updates
 	apiServer *api.API
 
@@ -73,8 +78,8 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 	ctx, cancel := context.WithCancel(context.Background())
 	pm := &PeerMonitor{
 		monitors:    make(map[int]*Client),
-		interval:    3 * time.Second, // Default check interval
-		timeout:     5 * time.Second,
+		interval:    2 * time.Second, // Default check interval (faster)
+		timeout:     3 * time.Second,
 		maxAttempts: 3,
 		wsClient:    wsClient,
 		middleDev:   middleDev,
@@ -83,13 +88,17 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		nsCtx:      ctx,
 		nsCancel:   cancel,
 		sharedBind: sharedBind,
-		holepunchInterval: 3 * time.Second, // Check holepunch every 5 seconds
-		holepunchTimeout:  5 * time.Second,
+		holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
+		holepunchTimeout:  2 * time.Second, // Faster timeout
 		holepunchEndpoints: make(map[int]string),
 		holepunchStatus:    make(map[int]bool),
 		relayedPeers:       make(map[int]bool),
-		holepunchMaxAttempts: 3, // Trigger relay after 5 consecutive failures
+		holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures
 		holepunchFailures:    make(map[int]int),
+		// Rapid initial test settings: worst case ~2.8s (5 x 400ms timeouts + 4 x 200ms gaps)
+		rapidTestInterval:    200 * time.Millisecond, // 200ms between attempts
+		rapidTestTimeout:     400 * time.Millisecond, // 400ms timeout per attempt
+		rapidTestMaxAttempts: 5,                      // up to 5 attempts before falling back to relay
 		apiServer:          apiServer,
 		wgConnectionStatus: make(map[int]bool),
 	}
@@ -182,10 +191,64 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
 
 // update holepunch endpoint for a peer
 func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
+	// Apply the new endpoint asynchronously after a short settling delay
+	go func() {
+		time.Sleep(3 * time.Second)
+		pm.mutex.Lock()
+		defer pm.mutex.Unlock()
+		pm.holepunchEndpoints[siteID] = endpoint
+	}()
+}
 
-	pm.holepunchEndpoints[siteID] = endpoint
+// RapidTestPeer performs a rapid connectivity test for a newly added peer.
+// It quickly determines whether holepunch is viable (worst case ~2.8s with the defaults).
+// Returns true if the connection is viable (holepunch works), false if it should relay.
+func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool {
+	if pm.holepunchTester == nil {
+		logger.Warn("Cannot perform rapid test: holepunch tester not initialized")
+		return false
+	}
+
+	pm.mutex.Lock()
+	interval := pm.rapidTestInterval
+	timeout := pm.rapidTestTimeout
+	maxAttempts := pm.rapidTestMaxAttempts
+	pm.mutex.Unlock()
+
+	logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)",
+		siteID, endpoint, maxAttempts, timeout)
+
+	for attempt := 1; attempt <= maxAttempts; attempt++ {
+		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
+
+		if result.Success {
+			logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)",
+				siteID, attempt, result.RTT)
+
+			// Update status
+			pm.mutex.Lock()
+			pm.holepunchStatus[siteID] = true
+			pm.holepunchFailures[siteID] = 0
+			pm.mutex.Unlock()
+
+			return true
+		}
+
+		if attempt < maxAttempts {
+			time.Sleep(interval)
+		}
+	}
+
+	logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay",
+		siteID, maxAttempts)
+
+	// Update status to reflect failure
+	pm.mutex.Lock()
+	pm.holepunchStatus[siteID] = false
+	pm.holepunchFailures[siteID] = maxAttempts
+	pm.mutex.Unlock()
+
+	return false
 }
 
 // UpdatePeerEndpoint updates the monitor endpoint for a peer
@@ -300,7 +363,13 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
 	return nil
 }
 
-// sendRelay sends a relay message to the server
+// RequestRelay is a public method to request relay for a peer.
+// It is used when the rapid initial test determines holepunch is not viable.
+func (pm *PeerMonitor) RequestRelay(siteID int) error {
+	return pm.sendRelay(siteID)
+}
+
+// sendUnRelay sends an unrelay message to the server
 func (pm *PeerMonitor) sendUnRelay(siteID int) error {
 	if pm.wsClient == nil {
 		return fmt.Errorf("websocket client is nil")
@@ -431,6 +500,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
 	pm.mutex.Unlock()
 
 	for siteID, endpoint := range endpoints {
+		logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
 		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
 
 		pm.mutex.Lock()
diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go
index 6204620..dac2008 100644
--- a/peers/monitor/wgtester.go
+++ b/peers/monitor/wgtester.go
@@ -157,14 +157,14 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
 			return false, 0
 		}
 
-		logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
+		// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
 		_, err := c.conn.Write(packet)
 		if err != nil {
 			c.connLock.Unlock()
 			logger.Info("Error sending packet: %v", err)
 			continue
 		}
-		logger.Debug("Successfully sent monitor packet")
+		// logger.Debug("Successfully sent monitor packet")
 
 		// Set read deadline
 		c.conn.SetReadDeadline(time.Now().Add(c.timeout))
diff --git a/peers/peer.go b/peers/peer.go
index 3e1b8d5..9370b9d 100644
--- a/peers/peer.go
+++ b/peers/peer.go
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
 	}
 	configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
 
-	configBuilder.WriteString("persistent_keepalive_interval=1\n")
+	configBuilder.WriteString("persistent_keepalive_interval=5\n")
 
 	config := configBuilder.String()
 	logger.Debug("Configuring peer with config: %s", config)
diff --git a/service_windows.go b/service_windows.go
index dc941f3..c103c46 100644
--- a/service_windows.go
+++ b/service_windows.go
@@ -163,6 +163,9 @@ func (s *olmService) runOlm() {
 	// Create a context that can be cancelled when the service stops
 	s.ctx, s.stop = context.WithCancel(context.Background())
 
+	// Create a separate context for programmatic shutdown (e.g., via API exit)
+	ctx, cancel := context.WithCancel(context.Background())
+
 	// Setup logging for service mode
 	s.elog.Info(1, "Starting Olm main logic")
 
@@ -177,7 +180,8 @@ func (s *olmService) runOlm() {
 		}()
 
 		// Call the main olm function with stored arguments
-		runOlmMainWithArgs(s.ctx, s.args)
+		// Use s.ctx as the signal context since the service manages shutdown
+		runOlmMainWithArgs(ctx, cancel, s.ctx, s.args)
 	}()
 
 	// Wait for either context cancellation or main logic completion
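Note on the latency budget: the rapid test's total duration is bounded by maxAttempts*timeout + (maxAttempts-1)*interval, which is ~2.8s with the defaults above (5 attempts, 400ms timeout, 200ms gap), not 1-1.5s. Below is a minimal standalone sketch of the same bounded-retry pattern; probeFunc and rapidTest are illustrative stand-ins, not code from this repo.

package main

import (
	"fmt"
	"time"
)

// probeFunc stands in for a single connectivity probe (e.g. one UDP
// request/response against the peer's holepunch endpoint) that blocks
// for at most timeout and reports success.
type probeFunc func(timeout time.Duration) bool

// rapidTest mirrors the loop in RapidTestPeer: up to maxAttempts probes,
// sleeping interval between failed attempts. Worst-case duration:
// maxAttempts*timeout + (maxAttempts-1)*interval.
func rapidTest(probe probeFunc, maxAttempts int, timeout, interval time.Duration) bool {
	for attempt := 1; attempt <= maxAttempts; attempt++ {
		if probe(timeout) {
			return true
		}
		if attempt < maxAttempts {
			time.Sleep(interval)
		}
	}
	return false
}

func main() {
	// Simulate a probe that always times out: 5*400ms + 4*200ms = 2.8s total.
	alwaysFails := func(timeout time.Duration) bool {
		time.Sleep(timeout)
		return false
	}
	start := time.Now()
	ok := rapidTest(alwaysFails, 5, 400*time.Millisecond, 200*time.Millisecond)
	fmt.Printf("viable=%v elapsed=%v\n", ok, time.Since(start).Round(100*time.Millisecond))
}

The success path returns as soon as one probe answers, so a working direct connection is confirmed well under a second; only the relay fallback pays the full ~2.8s.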
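Review note on UpdateHolepunchEndpoint: the fire-and-forget goroutine means two updates issued in quick succession can be applied out of order, and a sleeping goroutine outlives a stopped monitor. One possible alternative, sketched under the assumption that PeerMonitor gains a pendingEndpointUpdates map[int]*time.Timer field (hypothetical, not in this patch), lets a newer update cancel an older pending one:

// Hypothetical variant, not in this patch: assumes a new field on PeerMonitor,
//   pendingEndpointUpdates map[int]*time.Timer
// guarded by pm.mutex and initialized in NewPeerMonitor.
func (pm *PeerMonitor) updateHolepunchEndpointDeferred(siteID int, endpoint string) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	// A newer update cancels any older one still pending for this site.
	if t, ok := pm.pendingEndpointUpdates[siteID]; ok {
		t.Stop()
	}
	pm.pendingEndpointUpdates[siteID] = time.AfterFunc(3*time.Second, func() {
		pm.mutex.Lock()
		defer pm.mutex.Unlock()
		pm.holepunchEndpoints[siteID] = endpoint
		delete(pm.pendingEndpointUpdates, siteID)
	})
}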