diff --git a/olm/device_filter.go b/device/middle_device.go similarity index 69% rename from olm/device_filter.go rename to device/middle_device.go index fcd23db..82c13ac 100644 --- a/olm/device_filter.go +++ b/device/middle_device.go @@ -1,7 +1,6 @@ -package olm +package device import ( - "encoding/binary" "net/netip" "sync" @@ -17,23 +16,23 @@ type FilterRule struct { Handler PacketHandler } -// FilteredDevice wraps a TUN device with packet filtering capabilities -type FilteredDevice struct { +// MiddleDevice wraps a TUN device with packet filtering capabilities +type MiddleDevice struct { tun.Device rules []FilterRule mutex sync.RWMutex } -// NewFilteredDevice creates a new filtered TUN device wrapper -func NewFilteredDevice(device tun.Device) *FilteredDevice { - return &FilteredDevice{ +// NewMiddleDevice creates a new filtered TUN device wrapper +func NewMiddleDevice(device tun.Device) *MiddleDevice { + return &MiddleDevice{ Device: device, rules: make([]FilterRule, 0), } } // AddRule adds a packet filtering rule -func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) { +func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) { d.mutex.Lock() defer d.mutex.Unlock() d.rules = append(d.rules, FilterRule{ @@ -43,7 +42,7 @@ func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) { } // RemoveRule removes all rules for a given destination IP -func (d *FilteredDevice) RemoveRule(destIP netip.Addr) { +func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { d.mutex.Lock() defer d.mutex.Unlock() newRules := make([]FilterRule, 0, len(d.rules)) @@ -86,7 +85,7 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { } // Read intercepts packets going UP from the TUN device (towards WireGuard) -func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { n, err = d.Device.Read(bufs, sizes, offset) if err != nil || n == 0 { return n, err @@ -142,7 +141,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er } // Write intercepts packets going DOWN to the TUN device (from WireGuard) -func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { +func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) { d.mutex.RLock() rules := d.rules d.mutex.RUnlock() @@ -189,49 +188,3 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { return d.Device.Write(filteredBufs, offset) } - -// GetProtocol returns protocol number from IPv4 packet (fast path) -func GetProtocol(packet []byte) (uint8, bool) { - if len(packet) < 20 { - return 0, false - } - version := packet[0] >> 4 - if version == 4 { - return packet[9], true - } else if version == 6 { - if len(packet) < 40 { - return 0, false - } - return packet[6], true - } - return 0, false -} - -// GetDestPort returns destination port from TCP/UDP packet (fast path) -func GetDestPort(packet []byte) (uint16, bool) { - if len(packet) < 20 { - return 0, false - } - - version := packet[0] >> 4 - var headerLen int - - if version == 4 { - ihl := packet[0] & 0x0F - headerLen = int(ihl) * 4 - if len(packet) < headerLen+4 { - return 0, false - } - } else if version == 6 { - headerLen = 40 - if len(packet) < headerLen+4 { - return 0, false - } - } else { - return 0, false - } - - // Destination port is at bytes 2-3 of TCP/UDP header - port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4]) - return port, true -} diff --git a/olm/device_filter_test.go b/device/middle_device_test.go similarity index 95% rename from olm/device_filter_test.go rename to device/middle_device_test.go index 39a5f07..58cb88f 100644 --- a/olm/device_filter_test.go +++ b/device/middle_device_test.go @@ -1,8 +1,10 @@ -package olm +package device import ( "net/netip" "testing" + + "github.com/fosrl/newt/util" ) func TestExtractDestIP(t *testing.T) { @@ -74,7 +76,7 @@ func TestGetProtocol(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotProto, gotOk := GetProtocol(tt.packet) + gotProto, gotOk := util.GetProtocol(tt.packet) if gotOk != tt.wantOk { t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk) return diff --git a/olm/dns_proxy.go b/dns/dns_proxy.go similarity index 95% rename from olm/dns_proxy.go rename to dns/dns_proxy.go index 24e30a9..6ae7488 100644 --- a/olm/dns_proxy.go +++ b/dns/dns_proxy.go @@ -1,4 +1,4 @@ -package olm +package dns import ( "context" @@ -9,6 +9,8 @@ import ( "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + "github.com/fosrl/olm/device" "golang.zx2c4.com/wireguard/tun" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -96,9 +98,9 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) { } // Start starts the DNS proxy and registers with the filter -func (p *DNSProxy) Start(filter *FilteredDevice) error { +func (p *DNSProxy) Start(device *device.MiddleDevice) error { // Install packet filter rule - filter.AddRule(p.proxyIP, p.handlePacket) + device.AddRule(p.proxyIP, p.handlePacket) // Start DNS listener p.wg.Add(2) @@ -110,9 +112,9 @@ func (p *DNSProxy) Start(filter *FilteredDevice) error { } // Stop stops the DNS proxy -func (p *DNSProxy) Stop(filter *FilteredDevice) { - if filter != nil { - filter.RemoveRule(p.proxyIP) +func (p *DNSProxy) Stop(device *device.MiddleDevice) { + if device != nil { + device.RemoveRule(p.proxyIP) } p.cancel() p.wg.Wait() @@ -134,12 +136,12 @@ func (p *DNSProxy) handlePacket(packet []byte) bool { } // Quick check for UDP port 53 - proto, ok := GetProtocol(packet) + proto, ok := util.GetProtocol(packet) if !ok || proto != 17 { // 17 = UDP return false // Not UDP, don't handle } - port, ok := GetDestPort(packet) + port, ok := util.GetDestPort(packet) if !ok || port != DNSPort { return false // Not DNS port } diff --git a/olm/olm.go b/olm/olm.go index 4cfef4d..bc6f828 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -13,6 +13,8 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/api" + middleDevice "github.com/fosrl/olm/device" + "github.com/fosrl/olm/dns" "github.com/fosrl/olm/network" "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" @@ -70,8 +72,8 @@ var ( holePunchData HolePunchData uapiListener net.Listener tdev tun.Device - filteredDev *FilteredDevice - dnsProxy *DNSProxy + middleDev *middleDevice.MiddleDevice + dnsProxy *dns.DNSProxy apiServer *api.API olmClient *websocket.Client tunnelCancel context.CancelFunc @@ -427,15 +429,15 @@ func StartTunnel(config TunnelConfig) { } // Wrap TUN device with packet filter for DNS proxy - filteredDev = NewFilteredDevice(tdev) + middleDev = middleDevice.NewMiddleDevice(tdev) // Create and start DNS proxy - dnsProxy, err = NewDNSProxy(tdev, config.MTU) + dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU) if err != nil { logger.Error("Failed to create DNS proxy: %v", err) return } - if err := dnsProxy.Start(filteredDev); err != nil { + if err := dnsProxy.Start(middleDev); err != nil { logger.Error("Failed to start DNS proxy: %v", err) return } @@ -458,7 +460,7 @@ func StartTunnel(config TunnelConfig) { wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") // Use filtered device instead of raw TUN device - dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger)) + dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger)) // uapiListener, err = uapiListen(interfaceName, fileUAPI) // if err != nil { @@ -1067,13 +1069,13 @@ func Close() { // Stop DNS proxy if dnsProxy != nil { - dnsProxy.Stop(filteredDev) + dnsProxy.Stop(middleDev) dnsProxy = nil } // Clear filtered device - if filteredDev != nil { - filteredDev = nil + if middleDev != nil { + middleDev = nil } // Close TUN device