mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Shutting down correct now
This commit is contained in:
@@ -2,8 +2,10 @@ package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
@@ -50,10 +52,13 @@ func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
func (d *MiddleDevice) pump() {
|
||||
const defaultOffset = 16
|
||||
batchSize := d.Device.BatchSize()
|
||||
logger.Debug("MiddleDevice: pump started")
|
||||
|
||||
for {
|
||||
// Check closed first with priority
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel")
|
||||
return
|
||||
default:
|
||||
}
|
||||
@@ -69,13 +74,24 @@ func (d *MiddleDevice) pump() {
|
||||
|
||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||
|
||||
// Check closed again after read returns
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Now try to send the result
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -116,10 +132,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
func (d *MiddleDevice) Close() error {
|
||||
select {
|
||||
case <-d.closed:
|
||||
// Already closed
|
||||
return nil
|
||||
default:
|
||||
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
||||
close(d.closed)
|
||||
}
|
||||
return d.Device.Close()
|
||||
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
||||
err := d.Device.Close()
|
||||
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
@@ -154,9 +176,19 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
// Check if already closed first (non-blocking)
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
// Now block waiting for data
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
@@ -196,7 +228,8 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
|
||||
n = 1
|
||||
|
||||
case <-d.closed:
|
||||
return 0, nil // Device closed
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
||||
return 0, os.ErrClosed // Signal that device is closed
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
|
||||
@@ -124,14 +124,17 @@ func (p *DNSProxy) Stop() {
|
||||
p.middleDevice.RemoveRule(p.proxyIP)
|
||||
}
|
||||
p.cancel()
|
||||
|
||||
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
p.wg.Wait()
|
||||
|
||||
if p.stack != nil {
|
||||
p.stack.Close()
|
||||
}
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy stopped")
|
||||
}
|
||||
|
||||
BIN
olm-binary
BIN
olm-binary
Binary file not shown.
28
olm/olm.go
28
olm/olm.go
@@ -839,28 +839,24 @@ func Close() {
|
||||
uapiListener = nil
|
||||
}
|
||||
|
||||
// Close TUN device first to unblock any reads
|
||||
logger.Debug("Closing TUN device")
|
||||
if tdev != nil {
|
||||
tdev.Close()
|
||||
tdev = nil
|
||||
}
|
||||
|
||||
// Close filtered device (this will close the closed channel and stop pump goroutine)
|
||||
logger.Debug("Closing MiddleDevice")
|
||||
if middleDev != nil {
|
||||
middleDev.Close()
|
||||
middleDev = nil
|
||||
}
|
||||
|
||||
// Stop DNS proxy
|
||||
// Stop DNS proxy first - it uses the middleDev for packet filtering
|
||||
logger.Debug("Stopping DNS proxy")
|
||||
if dnsProxy != nil {
|
||||
dnsProxy.Stop()
|
||||
dnsProxy = nil
|
||||
}
|
||||
|
||||
// Now close WireGuard device
|
||||
// Close MiddleDevice first - this closes the TUN and signals the closed channel
|
||||
// This unblocks the pump goroutine and allows WireGuard's TUN reader to exit
|
||||
logger.Debug("Closing MiddleDevice")
|
||||
if middleDev != nil {
|
||||
middleDev.Close()
|
||||
middleDev = nil
|
||||
}
|
||||
// Note: tdev is closed by middleDev.Close() since middleDev wraps it
|
||||
tdev = nil
|
||||
|
||||
// Now close WireGuard device - its TUN reader should have exited by now
|
||||
logger.Debug("Closing WireGuard device")
|
||||
if dev != nil {
|
||||
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||
|
||||
Reference in New Issue
Block a user