mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Try to make the tun replacable
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
package device
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -18,14 +21,68 @@ type FilterRule struct {
|
||||
Handler PacketHandler
|
||||
}
|
||||
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
type MiddleDevice struct {
|
||||
// closeAwareDevice wraps a tun.Device along with a flag
|
||||
// indicating whether its Close method was called.
|
||||
type closeAwareDevice struct {
|
||||
isClosed atomic.Bool
|
||||
tun.Device
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
readCh chan readResult
|
||||
injectCh chan []byte
|
||||
closed chan struct{}
|
||||
closeEventCh chan struct{}
|
||||
wg sync.WaitGroup
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
|
||||
return &closeAwareDevice{
|
||||
Device: tunDevice,
|
||||
isClosed: atomic.Bool{},
|
||||
closeEventCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// redirectEvents redirects the Events() method of the underlying tun.Device
|
||||
// to the given channel.
|
||||
func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-c.Device.Events():
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if ev == tun.EventDown {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case out <- ev:
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
case <-c.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Close calls the underlying Device's Close method
|
||||
// after setting isClosed to true.
|
||||
func (c *closeAwareDevice) Close() (err error) {
|
||||
c.closeOnce.Do(func() {
|
||||
c.isClosed.Store(true)
|
||||
close(c.closeEventCh)
|
||||
err = c.Device.Close()
|
||||
c.wg.Wait()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *closeAwareDevice) IsClosed() bool {
|
||||
return c.isClosed.Load()
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
@@ -36,58 +93,124 @@ type readResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
// and supports swapping the underlying device.
|
||||
type MiddleDevice struct {
|
||||
devices []*closeAwareDevice
|
||||
mu sync.Mutex
|
||||
cond *sync.Cond
|
||||
rules []FilterRule
|
||||
rulesMutex sync.RWMutex
|
||||
readCh chan readResult
|
||||
injectCh chan []byte
|
||||
closed atomic.Bool
|
||||
events chan tun.Event
|
||||
}
|
||||
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
d := &MiddleDevice{
|
||||
Device: device,
|
||||
devices: make([]*closeAwareDevice, 0),
|
||||
rules: make([]FilterRule, 0),
|
||||
readCh: make(chan readResult),
|
||||
readCh: make(chan readResult, 16),
|
||||
injectCh: make(chan []byte, 100),
|
||||
closed: make(chan struct{}),
|
||||
events: make(chan tun.Event, 16),
|
||||
}
|
||||
go d.pump()
|
||||
d.cond = sync.NewCond(&d.mu)
|
||||
|
||||
if device != nil {
|
||||
d.AddDevice(device)
|
||||
}
|
||||
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) pump() {
|
||||
// AddDevice adds a new underlying TUN device, closing any previous one
|
||||
func (d *MiddleDevice) AddDevice(device tun.Device) {
|
||||
d.mu.Lock()
|
||||
if d.closed.Load() {
|
||||
d.mu.Unlock()
|
||||
_ = device.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var toClose *closeAwareDevice
|
||||
if len(d.devices) > 0 {
|
||||
toClose = d.devices[len(d.devices)-1]
|
||||
}
|
||||
|
||||
cad := newCloseAwareDevice(device)
|
||||
cad.redirectEvents(d.events)
|
||||
|
||||
d.devices = []*closeAwareDevice{cad}
|
||||
|
||||
// Start pump for the new device
|
||||
go d.pump(cad)
|
||||
|
||||
d.cond.Broadcast()
|
||||
d.mu.Unlock()
|
||||
|
||||
if toClose != nil {
|
||||
logger.Debug("MiddleDevice: Closing previous device")
|
||||
if err := toClose.Close(); err != nil {
|
||||
logger.Debug("MiddleDevice: Error closing previous device: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) pump(dev *closeAwareDevice) {
|
||||
const defaultOffset = 16
|
||||
batchSize := d.Device.BatchSize()
|
||||
logger.Debug("MiddleDevice: pump started")
|
||||
batchSize := dev.BatchSize()
|
||||
logger.Debug("MiddleDevice: pump started for device")
|
||||
|
||||
for {
|
||||
// Check closed first with priority
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel")
|
||||
// Check if this device is closed
|
||||
if dev.IsClosed() {
|
||||
logger.Debug("MiddleDevice: pump exiting, device is closed")
|
||||
return
|
||||
}
|
||||
|
||||
// Check if MiddleDevice itself is closed
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Allocate buffers for reading
|
||||
// We allocate new buffers for each read to avoid race conditions
|
||||
// since we pass them to the channel
|
||||
bufs := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||
}
|
||||
|
||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||
n, err := dev.Read(bufs, sizes, defaultOffset)
|
||||
|
||||
// Check closed again after read returns
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
|
||||
// Check if device was closed during read
|
||||
if dev.IsClosed() {
|
||||
logger.Debug("MiddleDevice: pump exiting, device closed during read")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Now try to send the result
|
||||
// Check if MiddleDevice was closed during read
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
|
||||
return
|
||||
}
|
||||
|
||||
// Try to send the result
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
|
||||
return
|
||||
default:
|
||||
// Channel full, check if we should exit
|
||||
if dev.IsClosed() || d.closed.Load() {
|
||||
return
|
||||
}
|
||||
// Try again with blocking
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-dev.closeEventCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@@ -99,16 +222,21 @@ func (d *MiddleDevice) pump() {
|
||||
|
||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||
if d.closed.Load() {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case d.injectCh <- packet:
|
||||
case <-d.closed:
|
||||
default:
|
||||
// Channel full, drop packet
|
||||
logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rulesMutex.Lock()
|
||||
defer d.rulesMutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
DestIP: destIP,
|
||||
Handler: handler,
|
||||
@@ -117,8 +245,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rulesMutex.Lock()
|
||||
defer d.rulesMutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
for _, rule := range d.rules {
|
||||
if rule.DestIP != destIP {
|
||||
@@ -130,18 +258,113 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
|
||||
// Close stops the device
|
||||
func (d *MiddleDevice) Close() error {
|
||||
select {
|
||||
case <-d.closed:
|
||||
// Already closed
|
||||
return nil
|
||||
default:
|
||||
logger.Debug("MiddleDevice: Closing, signaling closed channel")
|
||||
close(d.closed)
|
||||
if !d.closed.CompareAndSwap(false, true) {
|
||||
return nil // already closed
|
||||
}
|
||||
logger.Debug("MiddleDevice: Closing underlying TUN device")
|
||||
err := d.Device.Close()
|
||||
logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
|
||||
return err
|
||||
|
||||
d.mu.Lock()
|
||||
devices := d.devices
|
||||
d.devices = nil
|
||||
d.cond.Broadcast()
|
||||
d.mu.Unlock()
|
||||
|
||||
var lastErr error
|
||||
logger.Debug("MiddleDevice: Closing %d devices", len(devices))
|
||||
for _, device := range devices {
|
||||
if err := device.Close(); err != nil {
|
||||
logger.Debug("MiddleDevice: Error closing device: %v", err)
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
|
||||
close(d.events)
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// Events returns the events channel
|
||||
func (d *MiddleDevice) Events() <-chan tun.Event {
|
||||
return d.events
|
||||
}
|
||||
|
||||
// File returns the underlying file descriptor
|
||||
func (d *MiddleDevice) File() *os.File {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
file := dev.File()
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return file
|
||||
}
|
||||
}
|
||||
|
||||
// MTU returns the MTU of the underlying device
|
||||
func (d *MiddleDevice) MTU() (int, error) {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
mtu, err := dev.MTU()
|
||||
if err == nil {
|
||||
return mtu, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the name of the underlying device
|
||||
func (d *MiddleDevice) Name() (string, error) {
|
||||
for {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return "", io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
name, err := dev.Name()
|
||||
if err == nil {
|
||||
return name, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
// BatchSize returns the batch size
|
||||
func (d *MiddleDevice) BatchSize() int {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
return 1
|
||||
}
|
||||
return dev.BatchSize()
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
@@ -176,156 +399,231 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
// Check if already closed first (non-blocking)
|
||||
select {
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
|
||||
return 0, os.ErrClosed
|
||||
default:
|
||||
}
|
||||
|
||||
// Now block waiting for data
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
// Handle offset mismatch if necessary
|
||||
// We assume the pump used defaultOffset (16)
|
||||
// If caller asks for different offset, we need to shift
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
// Calculate where the packet data starts and ends in src
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
// Ensure dest buffer is large enough
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue // Skip if buffer too small
|
||||
// Wait for a device to be available
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt := <-d.injectCh:
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil // Buffer too small
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
|
||||
case <-d.closed:
|
||||
logger.Debug("MiddleDevice: Read returning os.ErrClosed")
|
||||
return 0, os.ErrClosed // Signal that device is closed
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
// Now block waiting for data from readCh or injectCh
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
// Check if device was swapped
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt := <-d.injectCh:
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
}
|
||||
|
||||
// Apply filtering rules
|
||||
d.rulesMutex.RLock()
|
||||
rules := d.rules
|
||||
d.rulesMutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
// Keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
if !handled {
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return writeIdx, err
|
||||
return writeIdx, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if len(rules) == 0 {
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// Filter packets going down
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
d.rulesMutex.RLock()
|
||||
rules := d.rules
|
||||
d.rulesMutex.RUnlock()
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
var filteredBufs [][]byte
|
||||
if len(rules) == 0 {
|
||||
filteredBufs = bufs
|
||||
} else {
|
||||
filteredBufs = make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil // All packets were handled
|
||||
}
|
||||
n, err := dev.Write(filteredBufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
return d.Device.Write(filteredBufs, offset)
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) waitForDevice() bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
for len(d.devices) == 0 && !d.closed.Load() {
|
||||
d.cond.Wait()
|
||||
}
|
||||
return !d.closed.Load()
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) peekLast() *closeAwareDevice {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
|
||||
if len(d.devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return d.devices[len(d.devices)-1]
|
||||
}
|
||||
|
||||
// WriteToTun writes packets directly to the underlying TUN device,
|
||||
// bypassing WireGuard. This is useful for sending packets that should
|
||||
// appear to come from the TUN interface (e.g., DNS responses from a proxy).
|
||||
// Unlike Write(), this does not go through packet filtering rules.
|
||||
func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
|
||||
for {
|
||||
if d.closed.Load() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
dev := d.peekLast()
|
||||
if dev == nil {
|
||||
if !d.waitForDevice() {
|
||||
return 0, io.EOF
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
n, err := dev.Write(bufs, offset)
|
||||
if err == nil {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
if dev.IsClosed() {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
@@ -5,6 +5,7 @@ package device
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.org/x/sys/unix"
|
||||
@@ -13,6 +14,11 @@ import (
|
||||
)
|
||||
|
||||
func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
|
||||
if runtime.GOOS == "android" { // otherwise we get a permission denied
|
||||
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
|
||||
return theTun, err
|
||||
}
|
||||
|
||||
dupTunFd, err := unix.Dup(int(tunFd))
|
||||
if err != nil {
|
||||
logger.Error("Unable to dup tun fd: %v", err)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"github.com/miekg/dns"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
@@ -36,8 +35,7 @@ type DNSProxy struct {
|
||||
upstreamDNS []string
|
||||
tunnelDNS bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
|
||||
middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
|
||||
recordStore *DNSRecordStore // Local DNS records
|
||||
|
||||
// Tunnel DNS fields - for sending queries over WireGuard
|
||||
@@ -53,7 +51,7 @@ type DNSProxy struct {
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||
func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
|
||||
proxyIP, err := PickIPFromSubnet(utilitySubnet)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
|
||||
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
middleDevice: middleDevice,
|
||||
upstreamDNS: upstreamDns,
|
||||
tunnelDNS: tunnelDns,
|
||||
@@ -694,9 +691,9 @@ func (p *DNSProxy) runPacketSender() {
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Write packet to TUN device
|
||||
// Write packet to TUN device via MiddleDevice
|
||||
// offset=16 indicates packet data starts at position 16 in the buffer
|
||||
_, err := p.tunDevice.Write([][]byte{buf}, offset)
|
||||
_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||
}
|
||||
|
||||
53
olm/olm.go
53
olm/olm.go
@@ -35,6 +35,7 @@ var (
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
middleDev *olmDevice.MiddleDevice
|
||||
interfaceName string
|
||||
dnsProxy *dns.DNSProxy
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
@@ -237,11 +238,11 @@ func StartTunnel(config TunnelConfig) {
|
||||
stopPing = make(chan struct{})
|
||||
|
||||
var (
|
||||
interfaceName = config.InterfaceName
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
userToken = config.UserToken
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
userToken = config.UserToken
|
||||
)
|
||||
interfaceName = config.InterfaceName
|
||||
|
||||
apiServer.SetOrgID(config.OrgID)
|
||||
|
||||
@@ -307,12 +308,7 @@ func StartTunnel(config TunnelConfig) {
|
||||
|
||||
tdev, err = func() (tun.Device, error) {
|
||||
if config.FileDescriptorTun != 0 {
|
||||
if runtime.GOOS == "android" { // otherwise we get a permission denied
|
||||
theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun))
|
||||
return theTun, err
|
||||
} else {
|
||||
return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
|
||||
}
|
||||
return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
|
||||
}
|
||||
var ifName = interfaceName
|
||||
if runtime.GOOS == "darwin" { // this is if we dont pass a fd
|
||||
@@ -329,11 +325,11 @@ func StartTunnel(config TunnelConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
if config.FileDescriptorTun == 0 {
|
||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
// if config.FileDescriptorTun == 0 {
|
||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
// }
|
||||
|
||||
// Wrap TUN device with packet filter for DNS proxy
|
||||
middleDev = olmDevice.NewMiddleDevice(tdev)
|
||||
@@ -389,7 +385,7 @@ func StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
|
||||
// Create and start DNS proxy
|
||||
dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
|
||||
dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS proxy: %v", err)
|
||||
}
|
||||
@@ -956,6 +952,33 @@ func StartTunnel(config TunnelConfig) {
|
||||
logger.Info("Tunnel process context cancelled, cleaning up")
|
||||
}
|
||||
|
||||
func AddDevice(fd uint32) {
|
||||
if middleDev == nil {
|
||||
logger.Error("MiddleDevice is nil, cannot add device")
|
||||
return
|
||||
}
|
||||
|
||||
if tunnelConfig.MTU == 0 {
|
||||
logger.Error("No MTU configured, cannot create device")
|
||||
return
|
||||
}
|
||||
|
||||
tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU)
|
||||
|
||||
if err != nil {
|
||||
logger.Error("Failed to create TUN device: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// if config.FileDescriptorTun == 0 {
|
||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
|
||||
interfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
// Here we replace the existing TUN device in the middle device with the new one
|
||||
middleDev.AddDevice(tdev)
|
||||
}
|
||||
|
||||
func Close() {
|
||||
// Restore original DNS configuration
|
||||
// we do this first to avoid any DNS issues if something else gets stuck
|
||||
|
||||
Reference in New Issue
Block a user