mirror of
https://github.com/fosrl/olm.git
synced 2026-02-26 06:46:48 +00:00
@@ -2,8 +2,10 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/logger"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -50,10 +52,13 @@ func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
|||||||
func (d *MiddleDevice) pump() {
|
func (d *MiddleDevice) pump() {
|
||||||
const defaultOffset = 16
|
const defaultOffset = 16
|
||||||
batchSize := d.Device.BatchSize()
|
batchSize := d.Device.BatchSize()
|
||||||
|
logger.Debug("MiddleDevice: pump started")
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
// Check closed first with priority
|
||||||
select {
|
select {
|
||||||
case <-d.closed:
|
case <-d.closed:
|
||||||
|
logger.Debug("MiddleDevice: pump exiting due to closed channel")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
@@ -69,13 +74,24 @@ func (d *MiddleDevice) pump() {
|
|||||||
|
|
||||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||||
|
|
||||||
|
// Check closed again after read returns
|
||||||
|
select {
|
||||||
|
case <-d.closed:
|
||||||
|
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now try to send the result
|
||||||
select {
|
select {
|
||||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||||
case <-d.closed:
|
case <-d.closed:
|
||||||
|
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logger.Debug("MiddleDevice: pump exiting due to read error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,10 +132,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
|||||||
func (d *MiddleDevice) Close() error {
|
func (d *MiddleDevice) Close() error {
|
||||||
select {
|
select {
|
||||||
case <-d.closed:
|
case <-d.closed:
|
||||||
|
// Already closed
|
||||||
|
return nil
|
||||||
default:
|
default:
|
||||||
|
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
||||||
close(d.closed)
|
close(d.closed)
|
||||||
}
|
}
|
||||||
return d.Device.Close()
|
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
||||||
|
err := d.Device.Close()
|
||||||
|
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractDestIP extracts destination IP from packet (fast path)
|
// extractDestIP extracts destination IP from packet (fast path)
|
||||||
@@ -154,9 +176,19 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
|||||||
|
|
||||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
|
// Check if already closed first (non-blocking)
|
||||||
|
select {
|
||||||
|
case <-d.closed:
|
||||||
|
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now block waiting for data
|
||||||
select {
|
select {
|
||||||
case res := <-d.readCh:
|
case res := <-d.readCh:
|
||||||
if res.err != nil {
|
if res.err != nil {
|
||||||
|
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||||
return 0, res.err
|
return 0, res.err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,7 +228,8 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
|
|||||||
n = 1
|
n = 1
|
||||||
|
|
||||||
case <-d.closed:
|
case <-d.closed:
|
||||||
return 0, nil // Device closed
|
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
||||||
|
return 0, os.ErrClosed // Signal that device is closed
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mutex.RLock()
|
d.mutex.RLock()
|
||||||
|
|||||||
@@ -124,14 +124,17 @@ func (p *DNSProxy) Stop() {
|
|||||||
p.middleDevice.RemoveRule(p.proxyIP)
|
p.middleDevice.RemoveRule(p.proxyIP)
|
||||||
}
|
}
|
||||||
p.cancel()
|
p.cancel()
|
||||||
|
|
||||||
|
// Close the endpoint first to unblock any pending Read() calls in runPacketSender
|
||||||
|
if p.ep != nil {
|
||||||
|
p.ep.Close()
|
||||||
|
}
|
||||||
|
|
||||||
p.wg.Wait()
|
p.wg.Wait()
|
||||||
|
|
||||||
if p.stack != nil {
|
if p.stack != nil {
|
||||||
p.stack.Close()
|
p.stack.Close()
|
||||||
}
|
}
|
||||||
if p.ep != nil {
|
|
||||||
p.ep.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("DNS proxy stopped")
|
logger.Info("DNS proxy stopped")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
c94f554cb06ba7952df7cd58d7d8620fd1eddc82
|
|
||||||
28
olm/olm.go
28
olm/olm.go
@@ -839,28 +839,24 @@ func Close() {
|
|||||||
uapiListener = nil
|
uapiListener = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close TUN device first to unblock any reads
|
// Stop DNS proxy first - it uses the middleDev for packet filtering
|
||||||
logger.Debug("Closing TUN device")
|
|
||||||
if tdev != nil {
|
|
||||||
tdev.Close()
|
|
||||||
tdev = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close filtered device (this will close the closed channel and stop pump goroutine)
|
|
||||||
logger.Debug("Closing MiddleDevice")
|
|
||||||
if middleDev != nil {
|
|
||||||
middleDev.Close()
|
|
||||||
middleDev = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop DNS proxy
|
|
||||||
logger.Debug("Stopping DNS proxy")
|
logger.Debug("Stopping DNS proxy")
|
||||||
if dnsProxy != nil {
|
if dnsProxy != nil {
|
||||||
dnsProxy.Stop()
|
dnsProxy.Stop()
|
||||||
dnsProxy = nil
|
dnsProxy = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now close WireGuard device
|
// Close MiddleDevice first - this closes the TUN and signals the closed channel
|
||||||
|
// This unblocks the pump goroutine and allows WireGuard's TUN reader to exit
|
||||||
|
logger.Debug("Closing MiddleDevice")
|
||||||
|
if middleDev != nil {
|
||||||
|
middleDev.Close()
|
||||||
|
middleDev = nil
|
||||||
|
}
|
||||||
|
// Note: tdev is closed by middleDev.Close() since middleDev wraps it
|
||||||
|
tdev = nil
|
||||||
|
|
||||||
|
// Now close WireGuard device - its TUN reader should have exited by now
|
||||||
logger.Debug("Closing WireGuard device")
|
logger.Debug("Closing WireGuard device")
|
||||||
if dev != nil {
|
if dev != nil {
|
||||||
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
dev.Close() // This will call sharedBind.Close() which releases WireGuard's reference
|
||||||
|
|||||||
Reference in New Issue
Block a user