Reorg the files

Former-commit-id: 5505c1d2c7
This commit is contained in:
Owen
2025-11-21 16:53:54 -05:00
parent 7941479994
commit d7cd746cc9
4 changed files with 35 additions and 76 deletions

View File

@@ -1,7 +1,6 @@
package olm package device
import ( import (
"encoding/binary"
"net/netip" "net/netip"
"sync" "sync"
@@ -17,23 +16,23 @@ type FilterRule struct {
Handler PacketHandler Handler PacketHandler
} }
// FilteredDevice wraps a TUN device with packet filtering capabilities // MiddleDevice wraps a TUN device with packet filtering capabilities
type FilteredDevice struct { type MiddleDevice struct {
tun.Device tun.Device
rules []FilterRule rules []FilterRule
mutex sync.RWMutex mutex sync.RWMutex
} }
// NewFilteredDevice creates a new filtered TUN device wrapper // NewMiddleDevice creates a new filtered TUN device wrapper
func NewFilteredDevice(device tun.Device) *FilteredDevice { func NewMiddleDevice(device tun.Device) *MiddleDevice {
return &FilteredDevice{ return &MiddleDevice{
Device: device, Device: device,
rules: make([]FilterRule, 0), rules: make([]FilterRule, 0),
} }
} }
// AddRule adds a packet filtering rule // AddRule adds a packet filtering rule
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) { func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{ d.rules = append(d.rules, FilterRule{
@@ -43,7 +42,7 @@ func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
} }
// RemoveRule removes all rules for a given destination IP // RemoveRule removes all rules for a given destination IP
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) { func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock() d.mutex.Lock()
defer d.mutex.Unlock() defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules)) newRules := make([]FilterRule, 0, len(d.rules))
@@ -86,7 +85,7 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
} }
// Read intercepts packets going UP from the TUN device (towards WireGuard) // Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset) n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 { if err != nil || n == 0 {
return n, err return n, err
@@ -142,7 +141,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
} }
// Write intercepts packets going DOWN to the TUN device (from WireGuard) // Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock() d.mutex.RLock()
rules := d.rules rules := d.rules
d.mutex.RUnlock() d.mutex.RUnlock()
@@ -189,49 +188,3 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
return d.Device.Write(filteredBufs, offset) return d.Device.Write(filteredBufs, offset)
} }
// GetProtocol returns protocol number from IPv4 packet (fast path)
func GetProtocol(packet []byte) (uint8, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
if version == 4 {
return packet[9], true
} else if version == 6 {
if len(packet) < 40 {
return 0, false
}
return packet[6], true
}
return 0, false
}
// GetDestPort returns destination port from TCP/UDP packet (fast path)
func GetDestPort(packet []byte) (uint16, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
var headerLen int
if version == 4 {
ihl := packet[0] & 0x0F
headerLen = int(ihl) * 4
if len(packet) < headerLen+4 {
return 0, false
}
} else if version == 6 {
headerLen = 40
if len(packet) < headerLen+4 {
return 0, false
}
} else {
return 0, false
}
// Destination port is at bytes 2-3 of TCP/UDP header
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
return port, true
}

View File

@@ -1,8 +1,10 @@
package olm package device
import ( import (
"net/netip" "net/netip"
"testing" "testing"
"github.com/fosrl/newt/util"
) )
func TestExtractDestIP(t *testing.T) { func TestExtractDestIP(t *testing.T) {
@@ -74,7 +76,7 @@ func TestGetProtocol(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
gotProto, gotOk := GetProtocol(tt.packet) gotProto, gotOk := util.GetProtocol(tt.packet)
if gotOk != tt.wantOk { if gotOk != tt.wantOk {
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk) t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
return return

View File

@@ -1,4 +1,4 @@
package olm package dns
import ( import (
"context" "context"
@@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
"github.com/fosrl/olm/device"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
@@ -96,9 +98,9 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
} }
// Start starts the DNS proxy and registers with the filter // Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start(filter *FilteredDevice) error { func (p *DNSProxy) Start(device *device.MiddleDevice) error {
// Install packet filter rule // Install packet filter rule
filter.AddRule(p.proxyIP, p.handlePacket) device.AddRule(p.proxyIP, p.handlePacket)
// Start DNS listener // Start DNS listener
p.wg.Add(2) p.wg.Add(2)
@@ -110,9 +112,9 @@ func (p *DNSProxy) Start(filter *FilteredDevice) error {
} }
// Stop stops the DNS proxy // Stop stops the DNS proxy
func (p *DNSProxy) Stop(filter *FilteredDevice) { func (p *DNSProxy) Stop(device *device.MiddleDevice) {
if filter != nil { if device != nil {
filter.RemoveRule(p.proxyIP) device.RemoveRule(p.proxyIP)
} }
p.cancel() p.cancel()
p.wg.Wait() p.wg.Wait()
@@ -134,12 +136,12 @@ func (p *DNSProxy) handlePacket(packet []byte) bool {
} }
// Quick check for UDP port 53 // Quick check for UDP port 53
proto, ok := GetProtocol(packet) proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // 17 = UDP if !ok || proto != 17 { // 17 = UDP
return false // Not UDP, don't handle return false // Not UDP, don't handle
} }
port, ok := GetDestPort(packet) port, ok := util.GetDestPort(packet)
if !ok || port != DNSPort { if !ok || port != DNSPort {
return false // Not DNS port return false // Not DNS port
} }

View File

@@ -13,6 +13,8 @@ import (
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util" "github.com/fosrl/newt/util"
"github.com/fosrl/olm/api" "github.com/fosrl/olm/api"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/dns"
"github.com/fosrl/olm/network" "github.com/fosrl/olm/network"
"github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/websocket" "github.com/fosrl/olm/websocket"
@@ -70,8 +72,8 @@ var (
holePunchData HolePunchData holePunchData HolePunchData
uapiListener net.Listener uapiListener net.Listener
tdev tun.Device tdev tun.Device
filteredDev *FilteredDevice middleDev *middleDevice.MiddleDevice
dnsProxy *DNSProxy dnsProxy *dns.DNSProxy
apiServer *api.API apiServer *api.API
olmClient *websocket.Client olmClient *websocket.Client
tunnelCancel context.CancelFunc tunnelCancel context.CancelFunc
@@ -427,15 +429,15 @@ func StartTunnel(config TunnelConfig) {
} }
// Wrap TUN device with packet filter for DNS proxy // Wrap TUN device with packet filter for DNS proxy
filteredDev = NewFilteredDevice(tdev) middleDev = middleDevice.NewMiddleDevice(tdev)
// Create and start DNS proxy // Create and start DNS proxy
dnsProxy, err = NewDNSProxy(tdev, config.MTU) dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
if err != nil { if err != nil {
logger.Error("Failed to create DNS proxy: %v", err) logger.Error("Failed to create DNS proxy: %v", err)
return return
} }
if err := dnsProxy.Start(filteredDev); err != nil { if err := dnsProxy.Start(middleDev); err != nil {
logger.Error("Failed to start DNS proxy: %v", err) logger.Error("Failed to start DNS proxy: %v", err)
return return
} }
@@ -458,7 +460,7 @@ func StartTunnel(config TunnelConfig) {
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
// Use filtered device instead of raw TUN device // Use filtered device instead of raw TUN device
dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger)) dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
// uapiListener, err = uapiListen(interfaceName, fileUAPI) // uapiListener, err = uapiListen(interfaceName, fileUAPI)
// if err != nil { // if err != nil {
@@ -1067,13 +1069,13 @@ func Close() {
// Stop DNS proxy // Stop DNS proxy
if dnsProxy != nil { if dnsProxy != nil {
dnsProxy.Stop(filteredDev) dnsProxy.Stop(middleDev)
dnsProxy = nil dnsProxy = nil
} }
// Clear filtered device // Clear filtered device
if filteredDev != nil { if middleDev != nil {
filteredDev = nil middleDev = nil
} }
// Close TUN device // Close TUN device