Reorg the files

Former-commit-id: 5505c1d2c7
This commit is contained in:
Owen
2025-11-21 16:53:54 -05:00
parent 7941479994
commit d7cd746cc9
4 changed files with 35 additions and 76 deletions

View File

@@ -1,7 +1,6 @@
package olm
package device
import (
"encoding/binary"
"net/netip"
"sync"
@@ -17,23 +16,23 @@ type FilterRule struct {
Handler PacketHandler
}
// FilteredDevice wraps a TUN device with packet filtering capabilities
type FilteredDevice struct {
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
}
// NewFilteredDevice creates a new filtered TUN device wrapper
func NewFilteredDevice(device tun.Device) *FilteredDevice {
return &FilteredDevice{
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
return &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
}
}
// AddRule adds a packet filtering rule
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{
@@ -43,7 +42,7 @@ func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
}
// RemoveRule removes all rules for a given destination IP
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) {
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock()
defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
@@ -86,7 +85,7 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
}
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 {
return n, err
@@ -142,7 +141,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
@@ -189,49 +188,3 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
return d.Device.Write(filteredBufs, offset)
}
// GetProtocol returns protocol number from IPv4 packet (fast path)
func GetProtocol(packet []byte) (uint8, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
if version == 4 {
return packet[9], true
} else if version == 6 {
if len(packet) < 40 {
return 0, false
}
return packet[6], true
}
return 0, false
}
// GetDestPort returns destination port from TCP/UDP packet (fast path)
func GetDestPort(packet []byte) (uint16, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
var headerLen int
if version == 4 {
ihl := packet[0] & 0x0F
headerLen = int(ihl) * 4
if len(packet) < headerLen+4 {
return 0, false
}
} else if version == 6 {
headerLen = 40
if len(packet) < headerLen+4 {
return 0, false
}
} else {
return 0, false
}
// Destination port is at bytes 2-3 of TCP/UDP header
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
return port, true
}

View File

@@ -1,8 +1,10 @@
package olm
package device
import (
"net/netip"
"testing"
"github.com/fosrl/newt/util"
)
func TestExtractDestIP(t *testing.T) {
@@ -74,7 +76,7 @@ func TestGetProtocol(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotProto, gotOk := GetProtocol(tt.packet)
gotProto, gotOk := util.GetProtocol(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
return

View File

@@ -1,4 +1,4 @@
package olm
package dns
import (
"context"
@@ -9,6 +9,8 @@ import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -96,9 +98,9 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start(filter *FilteredDevice) error {
func (p *DNSProxy) Start(device *device.MiddleDevice) error {
// Install packet filter rule
filter.AddRule(p.proxyIP, p.handlePacket)
device.AddRule(p.proxyIP, p.handlePacket)
// Start DNS listener
p.wg.Add(2)
@@ -110,9 +112,9 @@ func (p *DNSProxy) Start(filter *FilteredDevice) error {
}
// Stop stops the DNS proxy
func (p *DNSProxy) Stop(filter *FilteredDevice) {
if filter != nil {
filter.RemoveRule(p.proxyIP)
func (p *DNSProxy) Stop(device *device.MiddleDevice) {
if device != nil {
device.RemoveRule(p.proxyIP)
}
p.cancel()
p.wg.Wait()
@@ -134,12 +136,12 @@ func (p *DNSProxy) handlePacket(packet []byte) bool {
}
// Quick check for UDP port 53
proto, ok := GetProtocol(packet)
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // 17 = UDP
return false // Not UDP, don't handle
}
port, ok := GetDestPort(packet)
port, ok := util.GetDestPort(packet)
if !ok || port != DNSPort {
return false // Not DNS port
}

View File

@@ -13,6 +13,8 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
"github.com/fosrl/olm/network"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket"
@@ -70,8 +72,8 @@ var (
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
filteredDev *FilteredDevice
dnsProxy *DNSProxy
middleDev *middleDevice.MiddleDevice
dnsProxy *dns.DNSProxy
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
@@ -427,15 +429,15 @@ func StartTunnel(config TunnelConfig) {
}
// Wrap TUN device with packet filter for DNS proxy
filteredDev = NewFilteredDevice(tdev)
middleDev = middleDevice.NewMiddleDevice(tdev)
// Create and start DNS proxy
dnsProxy, err = NewDNSProxy(tdev, config.MTU)
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
return
}
if err := dnsProxy.Start(filteredDev); err != nil {
if err := dnsProxy.Start(middleDev); err != nil {
logger.Error("Failed to start DNS proxy: %v", err)
return
}
@@ -458,7 +460,7 @@ func StartTunnel(config TunnelConfig) {
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device
dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger))
dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
// if err != nil {
@@ -1067,13 +1069,13 @@ func Close() {
// Stop DNS proxy
if dnsProxy != nil {
dnsProxy.Stop(filteredDev)
dnsProxy.Stop(middleDev)
dnsProxy = nil
}
// Clear filtered device
if filteredDev != nil {
filteredDev = nil
if middleDev != nil {
middleDev = nil
}
// Close TUN device