diff --git a/olm/olm.go b/olm/olm.go
index 3dce73a..25a3bea 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -1095,7 +1095,7 @@ func Close() {
 	}
 
 	if peerMonitor != nil {
-		peerMonitor.Stop()
+		peerMonitor.Close() // Close() also calls Stop() internally
 		peerMonitor = nil
 	}
 
@@ -1104,26 +1104,32 @@ func Close() {
 		uapiListener = nil
 	}
 
-	if dev != nil {
-		dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
-		dev = nil
+	// Close TUN device first to unblock any reads
+	logger.Debug("Closing TUN device")
+	if tdev != nil {
+		tdev.Close()
+		tdev = nil
+	}
+
+	// Close filtered device (this will close the closed channel and stop pump goroutine)
+	logger.Debug("Closing MiddleDevice")
+	if middleDev != nil {
+		middleDev.Close()
+		middleDev = nil
 	}
 
 	// Stop DNS proxy
+	logger.Debug("Stopping DNS proxy")
 	if dnsProxy != nil {
 		dnsProxy.Stop()
 		dnsProxy = nil
 	}
 
-	// Clear filtered device
-	if middleDev != nil {
-		middleDev = nil
-	}
-
-	// Close TUN device
-	if tdev != nil {
-		tdev.Close()
-		tdev = nil
+	// Now close WireGuard device
+	logger.Debug("Closing WireGuard device")
+	if dev != nil {
+		dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
+		dev = nil
 	}
 
 	// Release the hole punch reference to the shared bind
diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go
index 4abdb6d..4233238 100644
--- a/peermonitor/peermonitor.go
+++ b/peermonitor/peermonitor.go
@@ -302,14 +302,53 @@ func (pm *PeerMonitor) Close() {
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 
-	// Stop and close all clients
+	logger.Debug("PeerMonitor: Starting cleanup")
+
+	// Stop and close all clients first
 	for siteID, client := range pm.monitors {
+		logger.Debug("PeerMonitor: Stopping client for site %d", siteID)
 		client.StopMonitor()
 		client.Close()
 		delete(pm.monitors, siteID)
 	}
 
 	pm.running = false
+
+	// Clean up netstack resources
+	logger.Debug("PeerMonitor: Cancelling netstack context")
+	if pm.nsCancel != nil {
+		pm.nsCancel() // Signal goroutines to stop
+	}
+
+	// Close the channel endpoint to unblock any pending reads
+	logger.Debug("PeerMonitor: Closing endpoint")
+	if pm.ep != nil {
+		pm.ep.Close()
+	}
+
+	// Wait for packet sender goroutine to finish with timeout
+	logger.Debug("PeerMonitor: Waiting for goroutines to finish")
+	done := make(chan struct{})
+	go func() {
+		pm.nsWg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		logger.Debug("PeerMonitor: Goroutines finished cleanly")
+	case <-time.After(2 * time.Second):
+		logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway")
+	}
+
+	// Destroy the stack last, after all goroutines are done
+	logger.Debug("PeerMonitor: Destroying stack")
+	if pm.stack != nil {
+		pm.stack.Destroy()
+		pm.stack = nil
+	}
+
+	logger.Debug("PeerMonitor: Cleanup complete")
 }
 
 // TestPeer tests connectivity to a specific peer
@@ -463,40 +502,56 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
 // runPacketSender reads packets from netstack and injects them into WireGuard
 func (pm *PeerMonitor) runPacketSender() {
 	defer pm.nsWg.Done()
+	logger.Debug("PeerMonitor: Packet sender goroutine started")
+
+	// Use a ticker to periodically check for packets without blocking indefinitely
+	ticker := time.NewTicker(10 * time.Millisecond)
+	defer ticker.Stop()
 
 	for {
 		select {
 		case <-pm.nsCtx.Done():
+			logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
+			// Drain any remaining packets before exiting
+			for {
+				pkt := pm.ep.Read()
+				if pkt == nil {
+					break
+				}
+				pkt.DecRef()
+			}
+			logger.Debug("PeerMonitor: Packet sender goroutine exiting")
 			return
-		default:
-		}
+		case <-ticker.C:
+			// Try to read packets in batches
+			for i := 0; i < 10; i++ {
+				pkt := pm.ep.Read()
+				if pkt == nil {
+					break
+				}
 
-		pkt := pm.ep.Read()
-		if pkt == nil {
-			time.Sleep(1 * time.Millisecond)
-			continue
-		}
+				// Extract packet data
+				slices := pkt.AsSlices()
+				if len(slices) > 0 {
+					var totalSize int
+					for _, slice := range slices {
+						totalSize += len(slice)
+					}
 
-		// Extract packet data
-		slices := pkt.AsSlices()
-		if len(slices) > 0 {
-			var totalSize int
-			for _, slice := range slices {
-				totalSize += len(slice)
+					buf := make([]byte, totalSize)
+					pos := 0
+					for _, slice := range slices {
+						copy(buf[pos:], slice)
+						pos += len(slice)
+					}
+
+					// Inject into MiddleDevice (outbound to WG)
+					pm.middleDev.InjectOutbound(buf)
+				}
+
+				pkt.DecRef()
 			}
-
-			buf := make([]byte, totalSize)
-			pos := 0
-			for _, slice := range slices {
-				copy(buf[pos:], slice)
-				pos += len(slice)
-			}
-
-			// Inject into MiddleDevice (outbound to WG)
-			pm.middleDev.InjectOutbound(buf)
 		}
-
-		pkt.DecRef()
 	}
 }
 
@@ -569,5 +624,8 @@ type trackedConn struct {
 
 func (c *trackedConn) Close() error {
 	c.pm.removePort(c.port)
-	return c.Conn.Close()
+	if c.Conn != nil {
+		return c.Conn.Close()
+	}
+	return nil
 }
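
Reviewer note: the bounded wait added to PeerMonitor.Close() is the standard Go
"WaitGroup with timeout" pattern. A minimal standalone sketch of the same idea,
for reference only — the helper name waitTimeout and the durations are
illustrative and not part of this patch:

	package main

	import (
		"fmt"
		"sync"
		"time"
	)

	// waitTimeout blocks until wg finishes or the timeout elapses,
	// returning true only if the WaitGroup completed in time.
	func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
		done := make(chan struct{})
		go func() {
			wg.Wait()
			close(done)
		}()
		select {
		case <-done:
			return true
		case <-time.After(timeout):
			return false
		}
	}

	func main() {
		var wg sync.WaitGroup
		wg.Add(1)
		go func() {
			defer wg.Done()
			time.Sleep(50 * time.Millisecond) // simulated shutdown work
		}()
		if waitTimeout(&wg, 2*time.Second) {
			fmt.Println("goroutines finished cleanly")
		} else {
			fmt.Println("timeout waiting for goroutines, proceeding anyway")
		}
	}

The tradeoff is the same one the patch accepts: if the timeout fires, the
watcher goroutine stays blocked on wg.Wait() until any stragglers finish, so
the timeout bounds shutdown latency rather than guaranteeing cleanup.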