Cleanup working

This commit is contained in:
Owen
2025-11-23 16:58:06 -05:00
parent a319baa298
commit d107e2d7de
2 changed files with 104 additions and 40 deletions

View File

@@ -1095,7 +1095,7 @@ func Close() {
}
if peerMonitor != nil {
peerMonitor.Stop()
peerMonitor.Close() // Close() also calls Stop() internally
peerMonitor = nil
}
@@ -1104,26 +1104,32 @@ func Close() {
uapiListener = nil
}
if dev != nil {
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
// Close TUN device first to unblock any reads
logger.Debug("Closing TUN device")
if tdev != nil {
tdev.Close()
tdev = nil
}
// Close filtered device (this will close the closed channel and stop pump goroutine)
logger.Debug("Closing MiddleDevice")
if middleDev != nil {
middleDev.Close()
middleDev = nil
}
// Stop DNS proxy
logger.Debug("Stopping DNS proxy")
if dnsProxy != nil {
dnsProxy.Stop()
dnsProxy = nil
}
// Clear filtered device
if middleDev != nil {
middleDev = nil
}
// Close TUN device
if tdev != nil {
tdev.Close()
tdev = nil
// Now close WireGuard device
logger.Debug("Closing WireGuard device")
if dev != nil {
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
dev = nil
}
// Release the hole punch reference to the shared bind

View File

@@ -302,14 +302,53 @@ func (pm *PeerMonitor) Close() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Stop and close all clients
logger.Debug("PeerMonitor: Starting cleanup")
// Stop and close all clients first
for siteID, client := range pm.monitors {
logger.Debug("PeerMonitor: Stopping client for site %d", siteID)
client.StopMonitor()
client.Close()
delete(pm.monitors, siteID)
}
pm.running = false
// Clean up netstack resources
logger.Debug("PeerMonitor: Cancelling netstack context")
if pm.nsCancel != nil {
pm.nsCancel() // Signal goroutines to stop
}
// Close the channel endpoint to unblock any pending reads
logger.Debug("PeerMonitor: Closing endpoint")
if pm.ep != nil {
pm.ep.Close()
}
// Wait for packet sender goroutine to finish with timeout
logger.Debug("PeerMonitor: Waiting for goroutines to finish")
done := make(chan struct{})
go func() {
pm.nsWg.Wait()
close(done)
}()
select {
case <-done:
logger.Debug("PeerMonitor: Goroutines finished cleanly")
case <-time.After(2 * time.Second):
logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway")
}
// Destroy the stack last, after all goroutines are done
logger.Debug("PeerMonitor: Destroying stack")
if pm.stack != nil {
pm.stack.Destroy()
pm.stack = nil
}
logger.Debug("PeerMonitor: Cleanup complete")
}
// TestPeer tests connectivity to a specific peer
@@ -463,40 +502,56 @@ func (pm *PeerMonitor) handlePacket(packet []byte) bool {
// runPacketSender reads packets from netstack and injects them into WireGuard
func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
logger.Debug("PeerMonitor: Packet sender goroutine started")
// Use a ticker to periodically check for packets without blocking indefinitely
ticker := time.NewTicker(10 * time.Millisecond)
defer ticker.Stop()
for {
select {
case <-pm.nsCtx.Done():
logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets")
// Drain any remaining packets before exiting
for {
pkt := pm.ep.Read()
if pkt == nil {
break
}
pkt.DecRef()
}
logger.Debug("PeerMonitor: Packet sender goroutine exiting")
return
default:
}
case <-ticker.C:
// Try to read packets in batches
for i := 0; i < 10; i++ {
pkt := pm.ep.Read()
if pkt == nil {
break
}
pkt := pm.ep.Read()
if pkt == nil {
time.Sleep(1 * time.Millisecond)
continue
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
@@ -569,5 +624,8 @@ type trackedConn struct {
func (c *trackedConn) Close() error {
c.pm.removePort(c.port)
return c.Conn.Close()
if c.Conn != nil {
return c.Conn.Close()
}
return nil
}