Mirror of https://github.com/fosrl/olm.git (synced 2026-02-07 21:46:40 +00:00)
package device (MiddleDevice):

@@ -1,9 +1,12 @@
 package device
 
 import (
+	"io"
 	"net/netip"
 	"os"
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/fosrl/newt/logger"
 	"golang.zx2c4.com/wireguard/tun"
@@ -18,14 +21,68 @@ type FilterRule struct {
 	Handler PacketHandler
 }
 
-// MiddleDevice wraps a TUN device with packet filtering capabilities
-type MiddleDevice struct {
+// closeAwareDevice wraps a tun.Device along with a flag
+// indicating whether its Close method was called.
+type closeAwareDevice struct {
+	isClosed atomic.Bool
 	tun.Device
-	rules    []FilterRule
-	mutex    sync.RWMutex
-	readCh   chan readResult
-	injectCh chan []byte
-	closed   chan struct{}
+	closeEventCh chan struct{}
+	wg           sync.WaitGroup
+	closeOnce    sync.Once
+}
+
+func newCloseAwareDevice(tunDevice tun.Device) *closeAwareDevice {
+	return &closeAwareDevice{
+		Device:       tunDevice,
+		isClosed:     atomic.Bool{},
+		closeEventCh: make(chan struct{}),
+	}
+}
+
+// redirectEvents redirects the Events() method of the underlying tun.Device
+// to the given channel.
+func (c *closeAwareDevice) redirectEvents(out chan tun.Event) {
+	c.wg.Add(1)
+	go func() {
+		defer c.wg.Done()
+		for {
+			select {
+			case ev, ok := <-c.Device.Events():
+				if !ok {
+					return
+				}
+
+				if ev == tun.EventDown {
+					continue
+				}
+
+				select {
+				case out <- ev:
+				case <-c.closeEventCh:
+					return
+				}
+			case <-c.closeEventCh:
+				return
+			}
+		}
+	}()
+}
+
+// Close calls the underlying Device's Close method
+// after setting isClosed to true.
+func (c *closeAwareDevice) Close() (err error) {
+	c.closeOnce.Do(func() {
+		c.isClosed.Store(true)
+		close(c.closeEventCh)
+		err = c.Device.Close()
+		c.wg.Wait()
+	})
+
+	return err
+}
+
+func (c *closeAwareDevice) IsClosed() bool {
+	return c.isClosed.Load()
 }
 
 type readResult struct {
@@ -36,58 +93,124 @@ type readResult struct {
 	err error
 }
 
+// MiddleDevice wraps a TUN device with packet filtering capabilities
+// and supports swapping the underlying device.
+type MiddleDevice struct {
+	devices    []*closeAwareDevice
+	mu         sync.Mutex
+	cond       *sync.Cond
+	rules      []FilterRule
+	rulesMutex sync.RWMutex
+	readCh     chan readResult
+	injectCh   chan []byte
+	closed     atomic.Bool
+	events     chan tun.Event
+}
+
 // NewMiddleDevice creates a new filtered TUN device wrapper
 func NewMiddleDevice(device tun.Device) *MiddleDevice {
 	d := &MiddleDevice{
-		Device:   device,
+		devices:  make([]*closeAwareDevice, 0),
 		rules:    make([]FilterRule, 0),
-		readCh:   make(chan readResult),
+		readCh:   make(chan readResult, 16),
 		injectCh: make(chan []byte, 100),
-		closed:   make(chan struct{}),
+		events:   make(chan tun.Event, 16),
 	}
-	go d.pump()
+	d.cond = sync.NewCond(&d.mu)
+
+	if device != nil {
+		d.AddDevice(device)
+	}
+
 	return d
 }
 
-func (d *MiddleDevice) pump() {
+// AddDevice adds a new underlying TUN device, closing any previous one
+func (d *MiddleDevice) AddDevice(device tun.Device) {
+	d.mu.Lock()
+	if d.closed.Load() {
+		d.mu.Unlock()
+		_ = device.Close()
+		return
+	}
+
+	var toClose *closeAwareDevice
+	if len(d.devices) > 0 {
+		toClose = d.devices[len(d.devices)-1]
+	}
+
+	cad := newCloseAwareDevice(device)
+	cad.redirectEvents(d.events)
+
+	d.devices = []*closeAwareDevice{cad}
+
+	// Start pump for the new device
+	go d.pump(cad)
+
+	d.cond.Broadcast()
+	d.mu.Unlock()
+
+	if toClose != nil {
+		logger.Debug("MiddleDevice: Closing previous device")
+		if err := toClose.Close(); err != nil {
+			logger.Debug("MiddleDevice: Error closing previous device: %v", err)
+		}
+	}
+}
+
+func (d *MiddleDevice) pump(dev *closeAwareDevice) {
 	const defaultOffset = 16
-	batchSize := d.Device.BatchSize()
-	logger.Debug("MiddleDevice: pump started")
+	batchSize := dev.BatchSize()
+	logger.Debug("MiddleDevice: pump started for device")
 
 	for {
-		// Check closed first with priority
-		select {
-		case <-d.closed:
-			logger.Debug("MiddleDevice: pump exiting due to closed channel")
+		// Check if this device is closed
+		if dev.IsClosed() {
+			logger.Debug("MiddleDevice: pump exiting, device is closed")
+			return
+		}
+
+		// Check if MiddleDevice itself is closed
+		if d.closed.Load() {
+			logger.Debug("MiddleDevice: pump exiting, MiddleDevice is closed")
 			return
-		default:
 		}
 
 		// Allocate buffers for reading
-		// We allocate new buffers for each read to avoid race conditions
-		// since we pass them to the channel
 		bufs := make([][]byte, batchSize)
 		sizes := make([]int, batchSize)
 		for i := range bufs {
 			bufs[i] = make([]byte, 2048) // Standard MTU + headroom
 		}
 
-		n, err := d.Device.Read(bufs, sizes, defaultOffset)
+		n, err := dev.Read(bufs, sizes, defaultOffset)
 
-		// Check closed again after read returns
-		select {
-		case <-d.closed:
-			logger.Debug("MiddleDevice: pump exiting due to closed channel (after read)")
+		// Check if device was closed during read
+		if dev.IsClosed() {
+			logger.Debug("MiddleDevice: pump exiting, device closed during read")
 			return
-		default:
 		}
 
-		// Now try to send the result
+		// Check if MiddleDevice was closed during read
+		if d.closed.Load() {
+			logger.Debug("MiddleDevice: pump exiting, MiddleDevice closed during read")
+			return
+		}
+
+		// Try to send the result
 		select {
 		case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
-		case <-d.closed:
-			logger.Debug("MiddleDevice: pump exiting due to closed channel (during send)")
-			return
+		default:
+			// Channel full, check if we should exit
+			if dev.IsClosed() || d.closed.Load() {
+				return
+			}
+			// Try again with blocking
+			select {
+			case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
+			case <-dev.closeEventCh:
+				return
+			}
 		}
 
 		if err != nil {
@@ -99,16 +222,21 @@ func (d *MiddleDevice) pump() {
 
 // InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
 func (d *MiddleDevice) InjectOutbound(packet []byte) {
+	if d.closed.Load() {
+		return
+	}
 	select {
 	case d.injectCh <- packet:
-	case <-d.closed:
+	default:
+		// Channel full, drop packet
+		logger.Debug("MiddleDevice: InjectOutbound dropping packet, channel full")
 	}
 }
 
 // AddRule adds a packet filtering rule
 func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
-	d.mutex.Lock()
-	defer d.mutex.Unlock()
+	d.rulesMutex.Lock()
+	defer d.rulesMutex.Unlock()
 	d.rules = append(d.rules, FilterRule{
 		DestIP:  destIP,
 		Handler: handler,
@@ -117,8 +245,8 @@ func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
 
 // RemoveRule removes all rules for a given destination IP
 func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
-	d.mutex.Lock()
-	defer d.mutex.Unlock()
+	d.rulesMutex.Lock()
+	defer d.rulesMutex.Unlock()
 	newRules := make([]FilterRule, 0, len(d.rules))
 	for _, rule := range d.rules {
 		if rule.DestIP != destIP {
@@ -130,18 +258,113 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
 
 // Close stops the device
 func (d *MiddleDevice) Close() error {
-	select {
-	case <-d.closed:
-		// Already closed
-		return nil
-	default:
-		logger.Debug("MiddleDevice: Closing, signaling closed channel")
-		close(d.closed)
+	if !d.closed.CompareAndSwap(false, true) {
+		return nil // already closed
 	}
-	logger.Debug("MiddleDevice: Closing underlying TUN device")
-	err := d.Device.Close()
-	logger.Debug("MiddleDevice: Underlying TUN device closed, err=%v", err)
-	return err
+
+	d.mu.Lock()
+	devices := d.devices
+	d.devices = nil
+	d.cond.Broadcast()
+	d.mu.Unlock()
+
+	var lastErr error
+	logger.Debug("MiddleDevice: Closing %d devices", len(devices))
+	for _, device := range devices {
+		if err := device.Close(); err != nil {
+			logger.Debug("MiddleDevice: Error closing device: %v", err)
+			lastErr = err
+		}
+	}
+
+	close(d.events)
+	return lastErr
+}
+
+// Events returns the events channel
+func (d *MiddleDevice) Events() <-chan tun.Event {
+	return d.events
+}
+
+// File returns the underlying file descriptor
+func (d *MiddleDevice) File() *os.File {
+	for {
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return nil
+			}
+			continue
+		}
+
+		file := dev.File()
+
+		if dev.IsClosed() {
+			time.Sleep(1 * time.Millisecond)
+			continue
+		}
+
+		return file
+	}
+}
+
+// MTU returns the MTU of the underlying device
+func (d *MiddleDevice) MTU() (int, error) {
+	for {
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return 0, io.EOF
+			}
+			continue
+		}
+
+		mtu, err := dev.MTU()
+		if err == nil {
+			return mtu, nil
+		}
+
+		if dev.IsClosed() {
+			time.Sleep(1 * time.Millisecond)
+			continue
+		}
+
+		return 0, err
+	}
+}
+
+// Name returns the name of the underlying device
+func (d *MiddleDevice) Name() (string, error) {
+	for {
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return "", io.EOF
+			}
+			continue
+		}
+
+		name, err := dev.Name()
+		if err == nil {
+			return name, nil
+		}
+
+		if dev.IsClosed() {
+			time.Sleep(1 * time.Millisecond)
+			continue
+		}
+
+		return "", err
+	}
+}
+
+// BatchSize returns the batch size
+func (d *MiddleDevice) BatchSize() int {
+	dev := d.peekLast()
+	if dev == nil {
+		return 1
+	}
+	return dev.BatchSize()
 }
 
 // extractDestIP extracts destination IP from packet (fast path)
@@ -176,156 +399,231 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
 
 // Read intercepts packets going UP from the TUN device (towards WireGuard)
 func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
-	// Check if already closed first (non-blocking)
-	select {
-	case <-d.closed:
-		logger.Debug("MiddleDevice: Read returning os.ErrClosed (pre-check)")
-		return 0, os.ErrClosed
-	default:
-	}
-
-	// Now block waiting for data
-	select {
-	case res := <-d.readCh:
-		if res.err != nil {
-			logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
-			return 0, res.err
-		}
-
-		// Copy packets from result to provided buffers
-		count := 0
-		for i := 0; i < res.n && i < len(bufs); i++ {
-			// Handle offset mismatch if necessary
-			// We assume the pump used defaultOffset (16)
-			// If caller asks for different offset, we need to shift
-			src := res.bufs[i]
-			srcOffset := res.offset
-			srcSize := res.sizes[i]
-
-			// Calculate where the packet data starts and ends in src
-			pktData := src[srcOffset : srcOffset+srcSize]
-
-			// Ensure dest buffer is large enough
-			if len(bufs[i]) < offset+len(pktData) {
-				continue // Skip if buffer too small
-			}
-
-			copy(bufs[i][offset:], pktData)
-			sizes[i] = len(pktData)
-			count++
-		}
-		n = count
-
-	case pkt := <-d.injectCh:
-		if len(bufs) == 0 {
-			return 0, nil
-		}
-		if len(bufs[0]) < offset+len(pkt) {
-			return 0, nil // Buffer too small
-		}
-		copy(bufs[0][offset:], pkt)
-		sizes[0] = len(pkt)
-		n = 1
-
-	case <-d.closed:
-		logger.Debug("MiddleDevice: Read returning os.ErrClosed")
-		return 0, os.ErrClosed // Signal that device is closed
-	}
-
-	d.mutex.RLock()
-	rules := d.rules
-	d.mutex.RUnlock()
-
-	if len(rules) == 0 {
-		return n, nil
-	}
-
-	// Process packets and filter out handled ones
-	writeIdx := 0
-	for readIdx := 0; readIdx < n; readIdx++ {
-		packet := bufs[readIdx][offset : offset+sizes[readIdx]]
-
-		destIP, ok := extractDestIP(packet)
-		if !ok {
-			// Can't parse, keep packet
-			if writeIdx != readIdx {
-				bufs[writeIdx] = bufs[readIdx]
-				sizes[writeIdx] = sizes[readIdx]
-			}
-			writeIdx++
-			continue
-		}
-
-		// Check if packet matches any rule
-		handled := false
-		for _, rule := range rules {
-			if rule.DestIP == destIP {
-				if rule.Handler(packet) {
-					// Packet was handled and should be dropped
-					handled = true
-					break
-				}
-			}
-		}
-
-		if !handled {
-			// Keep packet
-			if writeIdx != readIdx {
-				bufs[writeIdx] = bufs[readIdx]
-				sizes[writeIdx] = sizes[readIdx]
-			}
-			writeIdx++
-		}
-	}
-
-	return writeIdx, err
+	for {
+		if d.closed.Load() {
+			logger.Debug("MiddleDevice: Read returning io.EOF, device closed")
+			return 0, io.EOF
+		}
+
+		// Wait for a device to be available
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return 0, io.EOF
+			}
+			continue
+		}
+
+		// Now block waiting for data from readCh or injectCh
+		select {
+		case res := <-d.readCh:
+			if res.err != nil {
+				// Check if device was swapped
+				if dev.IsClosed() {
+					time.Sleep(1 * time.Millisecond)
+					continue
+				}
+				logger.Debug("MiddleDevice: Read returning error from pump: %v", res.err)
+				return 0, res.err
+			}
+
+			// Copy packets from result to provided buffers
+			count := 0
+			for i := 0; i < res.n && i < len(bufs); i++ {
+				src := res.bufs[i]
+				srcOffset := res.offset
+				srcSize := res.sizes[i]
+
+				pktData := src[srcOffset : srcOffset+srcSize]
+
+				if len(bufs[i]) < offset+len(pktData) {
+					continue
+				}
+
+				copy(bufs[i][offset:], pktData)
+				sizes[i] = len(pktData)
+				count++
+			}
+			n = count
+
+		case pkt := <-d.injectCh:
+			if len(bufs) == 0 {
+				return 0, nil
+			}
+			if len(bufs[0]) < offset+len(pkt) {
+				return 0, nil
+			}
+			copy(bufs[0][offset:], pkt)
+			sizes[0] = len(pkt)
+			n = 1
+		}
+
+		// Apply filtering rules
+		d.rulesMutex.RLock()
+		rules := d.rules
+		d.rulesMutex.RUnlock()
+
+		if len(rules) == 0 {
+			return n, nil
+		}
+
+		// Process packets and filter out handled ones
+		writeIdx := 0
+		for readIdx := 0; readIdx < n; readIdx++ {
+			packet := bufs[readIdx][offset : offset+sizes[readIdx]]
+
+			destIP, ok := extractDestIP(packet)
+			if !ok {
+				if writeIdx != readIdx {
+					bufs[writeIdx] = bufs[readIdx]
+					sizes[writeIdx] = sizes[readIdx]
+				}
+				writeIdx++
+				continue
+			}
+
+			handled := false
+			for _, rule := range rules {
+				if rule.DestIP == destIP {
+					if rule.Handler(packet) {
+						handled = true
+						break
+					}
+				}
+			}
+
+			if !handled {
+				if writeIdx != readIdx {
+					bufs[writeIdx] = bufs[readIdx]
+					sizes[writeIdx] = sizes[readIdx]
+				}
+				writeIdx++
+			}
+		}
+
+		return writeIdx, nil
+	}
 }
 
 // Write intercepts packets going DOWN to the TUN device (from WireGuard)
 func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
-	d.mutex.RLock()
-	rules := d.rules
-	d.mutex.RUnlock()
-
-	if len(rules) == 0 {
-		return d.Device.Write(bufs, offset)
-	}
-
-	// Filter packets going down
-	filteredBufs := make([][]byte, 0, len(bufs))
-	for _, buf := range bufs {
-		if len(buf) <= offset {
-			continue
-		}
-
-		packet := buf[offset:]
-		destIP, ok := extractDestIP(packet)
-		if !ok {
-			// Can't parse, keep packet
-			filteredBufs = append(filteredBufs, buf)
-			continue
-		}
-
-		// Check if packet matches any rule
-		handled := false
-		for _, rule := range rules {
-			if rule.DestIP == destIP {
-				if rule.Handler(packet) {
-					// Packet was handled and should be dropped
-					handled = true
-					break
-				}
-			}
-		}
-
-		if !handled {
-			filteredBufs = append(filteredBufs, buf)
-		}
-	}
-
-	if len(filteredBufs) == 0 {
-		return len(bufs), nil // All packets were handled
-	}
-
-	return d.Device.Write(filteredBufs, offset)
+	for {
+		if d.closed.Load() {
+			return 0, io.EOF
+		}
+
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return 0, io.EOF
+			}
+			continue
+		}
+
+		d.rulesMutex.RLock()
+		rules := d.rules
+		d.rulesMutex.RUnlock()
+
+		var filteredBufs [][]byte
+		if len(rules) == 0 {
+			filteredBufs = bufs
+		} else {
+			filteredBufs = make([][]byte, 0, len(bufs))
+			for _, buf := range bufs {
+				if len(buf) <= offset {
+					continue
+				}
+
+				packet := buf[offset:]
+				destIP, ok := extractDestIP(packet)
+				if !ok {
+					filteredBufs = append(filteredBufs, buf)
+					continue
+				}
+
+				handled := false
+				for _, rule := range rules {
+					if rule.DestIP == destIP {
+						if rule.Handler(packet) {
+							handled = true
+							break
+						}
+					}
+				}
+
+				if !handled {
+					filteredBufs = append(filteredBufs, buf)
+				}
+			}
+		}
+
+		if len(filteredBufs) == 0 {
+			return len(bufs), nil
+		}
+
+		n, err := dev.Write(filteredBufs, offset)
+		if err == nil {
+			return n, nil
+		}
+
+		if dev.IsClosed() {
+			time.Sleep(1 * time.Millisecond)
+			continue
+		}
+
+		return n, err
+	}
 }
+
+func (d *MiddleDevice) waitForDevice() bool {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+
+	for len(d.devices) == 0 && !d.closed.Load() {
+		d.cond.Wait()
+	}
+	return !d.closed.Load()
+}
+
+func (d *MiddleDevice) peekLast() *closeAwareDevice {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+
+	if len(d.devices) == 0 {
+		return nil
+	}
+
+	return d.devices[len(d.devices)-1]
+}
+
+// WriteToTun writes packets directly to the underlying TUN device,
+// bypassing WireGuard. This is useful for sending packets that should
+// appear to come from the TUN interface (e.g., DNS responses from a proxy).
+// Unlike Write(), this does not go through packet filtering rules.
+func (d *MiddleDevice) WriteToTun(bufs [][]byte, offset int) (int, error) {
+	for {
+		if d.closed.Load() {
+			return 0, io.EOF
+		}
+
+		dev := d.peekLast()
+		if dev == nil {
+			if !d.waitForDevice() {
+				return 0, io.EOF
+			}
+			continue
+		}
+
+		n, err := dev.Write(bufs, offset)
+		if err == nil {
+			return n, nil
+		}
+
+		if dev.IsClosed() {
+			time.Sleep(1 * time.Millisecond)
+			continue
+		}
+
+		return n, err
+	}
+}
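For orientation, here is a minimal sketch of how the reworked MiddleDevice is meant to be driven. It assumes the import path github.com/fosrl/olm/device, a PacketHandler of the form func([]byte) bool (the type definition sits above the first hunk and is not shown here), and a placeholder proxy address; it creates real TUN devices with tun.CreateTUN, which needs the usual privileges. Treat it as an illustration under those assumptions, not code from the repository.

package main

import (
	"log"
	"net/netip"

	olmdevice "github.com/fosrl/olm/device"
	"golang.zx2c4.com/wireguard/tun"
)

func main() {
	// First TUN device; on mobile this would come from CreateTUNFromFD instead.
	first, err := tun.CreateTUN("olm-demo", 1420)
	if err != nil {
		log.Fatalf("create tun: %v", err)
	}

	// WireGuard is handed the wrapper, never the raw device.
	mid := olmdevice.NewMiddleDevice(first)
	defer mid.Close()

	// Divert packets addressed to the DNS proxy IP; returning true drops
	// the packet from the normal Read/Write path (see the filtering above).
	proxyIP := netip.MustParseAddr("10.0.0.53") // placeholder address
	mid.AddRule(proxyIP, func(pkt []byte) bool {
		// hand pkt to the DNS proxy here
		return true
	})

	// Later, e.g. after the platform reissues the interface, swap the
	// underlying device without recreating the WireGuard device.
	second, err := tun.CreateTUN("olm-demo2", 1420)
	if err == nil {
		mid.AddDevice(second)
	}
}

This swap is what the new olm.AddDevice in olm/olm.go (further down) relies on: WireGuard keeps reading from the same MiddleDevice while the closeAwareDevice underneath is replaced.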
package device (CreateTUNFromFD):

@@ -5,6 +5,7 @@ package device
 import (
 	"net"
 	"os"
+	"runtime"
 
 	"github.com/fosrl/newt/logger"
 	"golang.org/x/sys/unix"
@@ -13,6 +14,11 @@ import (
 )
 
 func CreateTUNFromFD(tunFd uint32, mtuInt int) (tun.Device, error) {
+	if runtime.GOOS == "android" { // otherwise we get a permission denied
+		theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(tunFd))
+		return theTun, err
+	}
+
 	dupTunFd, err := unix.Dup(int(tunFd))
 	if err != nil {
 		logger.Error("Unable to dup tun fd: %v", err)
DNS proxy (DNSProxy):

@@ -12,7 +12,6 @@ import (
 	"github.com/fosrl/newt/util"
 	"github.com/fosrl/olm/device"
 	"github.com/miekg/dns"
-	"golang.zx2c4.com/wireguard/tun"
 	"gvisor.dev/gvisor/pkg/buffer"
 	"gvisor.dev/gvisor/pkg/tcpip"
 	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
@@ -36,8 +35,7 @@ type DNSProxy struct {
 	upstreamDNS []string
 	tunnelDNS   bool // Whether to tunnel DNS queries over WireGuard or to spit them out locally
 	mtu         int
-	tunDevice    tun.Device           // Direct reference to underlying TUN device for responses
-	middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering
+	middleDevice *device.MiddleDevice // Reference to MiddleDevice for packet filtering and TUN writes
 	recordStore  *DNSRecordStore      // Local DNS records
 
 	// Tunnel DNS fields - for sending queries over WireGuard
@@ -53,7 +51,7 @@ type DNSProxy struct {
 }
 
 // NewDNSProxy creates a new DNS proxy
-func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
+func NewDNSProxy(middleDevice *device.MiddleDevice, mtu int, utilitySubnet string, upstreamDns []string, tunnelDns bool, tunnelIP string) (*DNSProxy, error) {
 	proxyIP, err := PickIPFromSubnet(utilitySubnet)
 	if err != nil {
 		return nil, fmt.Errorf("failed to pick DNS proxy IP from subnet: %v", err)
@@ -68,7 +66,6 @@ func NewDNSProxy(tunDevice tun.Device, middleDevice *device.MiddleDevice, mtu in
 	proxy := &DNSProxy{
 		proxyIP:      proxyIP,
 		mtu:          mtu,
-		tunDevice:    tunDevice,
 		middleDevice: middleDevice,
 		upstreamDNS:  upstreamDns,
 		tunnelDNS:    tunnelDns,
@@ -694,9 +691,9 @@ func (p *DNSProxy) runPacketSender() {
 			pos += len(slice)
 		}
 
-		// Write packet to TUN device
+		// Write packet to TUN device via MiddleDevice
 		// offset=16 indicates packet data starts at position 16 in the buffer
-		_, err := p.tunDevice.Write([][]byte{buf}, offset)
+		_, err := p.middleDevice.WriteToTun([][]byte{buf}, offset)
 		if err != nil {
 			logger.Error("Failed to write DNS response to TUN: %v", err)
 		}
olm/olm.go (53 changed lines):

@@ -35,6 +35,7 @@ var (
 	uapiListener  net.Listener
 	tdev          tun.Device
 	middleDev     *olmDevice.MiddleDevice
+	interfaceName string
 	dnsProxy      *dns.DNSProxy
 	apiServer     *api.API
 	olmClient     *websocket.Client
@@ -237,11 +238,11 @@ func StartTunnel(config TunnelConfig) {
 	stopPing = make(chan struct{})
 
 	var (
-		interfaceName = config.InterfaceName
-		id            = config.ID
-		secret        = config.Secret
-		userToken     = config.UserToken
+		id        = config.ID
+		secret    = config.Secret
+		userToken = config.UserToken
 	)
+	interfaceName = config.InterfaceName
 
 	apiServer.SetOrgID(config.OrgID)
 
@@ -307,12 +308,7 @@ func StartTunnel(config TunnelConfig) {
 
 	tdev, err = func() (tun.Device, error) {
 		if config.FileDescriptorTun != 0 {
-			if runtime.GOOS == "android" { // otherwise we get a permission denied
-				theTun, _, err := tun.CreateUnmonitoredTUNFromFD(int(config.FileDescriptorTun))
-				return theTun, err
-			} else {
-				return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
-			}
+			return olmDevice.CreateTUNFromFD(config.FileDescriptorTun, config.MTU)
 		}
 		var ifName = interfaceName
 		if runtime.GOOS == "darwin" { // this is if we dont pass a fd
@@ -329,11 +325,11 @@ func StartTunnel(config TunnelConfig) {
 		return
 	}
 
-	if config.FileDescriptorTun == 0 {
-		if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+	// if config.FileDescriptorTun == 0 {
+	if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
 		interfaceName = realInterfaceName
-		}
 	}
+	// }
 
 	// Wrap TUN device with packet filter for DNS proxy
 	middleDev = olmDevice.NewMiddleDevice(tdev)
@@ -389,7 +385,7 @@ func StartTunnel(config TunnelConfig) {
 	}
 
 	// Create and start DNS proxy
-	dnsProxy, err = dns.NewDNSProxy(tdev, middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
+	dnsProxy, err = dns.NewDNSProxy(middleDev, config.MTU, wgData.UtilitySubnet, config.UpstreamDNS, config.TunnelDNS, interfaceIP)
 	if err != nil {
 		logger.Error("Failed to create DNS proxy: %v", err)
 	}
@@ -956,6 +952,33 @@ func StartTunnel(config TunnelConfig) {
 		logger.Info("Tunnel process context cancelled, cleaning up")
 	}
 
+func AddDevice(fd uint32) {
+	if middleDev == nil {
+		logger.Error("MiddleDevice is nil, cannot add device")
+		return
+	}
+
+	if tunnelConfig.MTU == 0 {
+		logger.Error("No MTU configured, cannot create device")
+		return
+	}
+
+	tdev, err := olmDevice.CreateTUNFromFD(fd, tunnelConfig.MTU)
+
+	if err != nil {
+		logger.Error("Failed to create TUN device: %v", err)
+		return
+	}
+
+	// if config.FileDescriptorTun == 0 {
+	if realInterfaceName, err2 := tdev.Name(); err2 == nil { // if the interface is defined then this should not really do anything?
+		interfaceName = realInterfaceName
+	}
+
+	// Here we replace the existing TUN device in the middle device with the new one
+	middleDev.AddDevice(tdev)
+}
+
 func Close() {
 	// Restore original DNS configuration
 	// we do this first to avoid any DNS issues if something else gets stuck
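The new exported AddDevice looks like the hook for platform glue that re-opens the TUN file descriptor (for example an Android VPN service after a network change). A hypothetical caller, assuming the olm package is imported from github.com/fosrl/olm/olm; the package path and the callback name are illustrative only:

package glue

import olm "github.com/fosrl/olm/olm"

// OnTunFdRenewed would be called by the host app whenever the platform
// hands out a fresh TUN file descriptor for the existing tunnel.
func OnTunFdRenewed(fd int) {
	// AddDevice swaps the descriptor into the running MiddleDevice
	// without tearing down the WireGuard device or the DNS proxy.
	olm.AddDevice(uint32(fd))
}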