mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
@@ -1,7 +1,6 @@
|
||||
package olm
|
||||
package device
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
@@ -17,23 +16,23 @@ type FilterRule struct {
|
||||
Handler PacketHandler
|
||||
}
|
||||
|
||||
// FilteredDevice wraps a TUN device with packet filtering capabilities
|
||||
type FilteredDevice struct {
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
type MiddleDevice struct {
|
||||
tun.Device
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFilteredDevice creates a new filtered TUN device wrapper
|
||||
func NewFilteredDevice(device tun.Device) *FilteredDevice {
|
||||
return &FilteredDevice{
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
return &MiddleDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
@@ -43,7 +42,7 @@ func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
}
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) {
|
||||
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
@@ -86,7 +85,7 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
n, err = d.Device.Read(bufs, sizes, offset)
|
||||
if err != nil || n == 0 {
|
||||
return n, err
|
||||
@@ -142,7 +141,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
@@ -189,49 +188,3 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
|
||||
return d.Device.Write(filteredBufs, offset)
|
||||
}
|
||||
|
||||
// GetProtocol returns protocol number from IPv4 packet (fast path)
|
||||
func GetProtocol(packet []byte) (uint8, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
version := packet[0] >> 4
|
||||
if version == 4 {
|
||||
return packet[9], true
|
||||
} else if version == 6 {
|
||||
if len(packet) < 40 {
|
||||
return 0, false
|
||||
}
|
||||
return packet[6], true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetDestPort returns destination port from TCP/UDP packet (fast path)
|
||||
func GetDestPort(packet []byte) (uint16, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
var headerLen int
|
||||
|
||||
if version == 4 {
|
||||
ihl := packet[0] & 0x0F
|
||||
headerLen = int(ihl) * 4
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else if version == 6 {
|
||||
headerLen = 40
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Destination port is at bytes 2-3 of TCP/UDP header
|
||||
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
|
||||
return port, true
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package olm
|
||||
package device
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/fosrl/newt/util"
|
||||
)
|
||||
|
||||
func TestExtractDestIP(t *testing.T) {
|
||||
@@ -74,7 +76,7 @@ func TestGetProtocol(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotProto, gotOk := GetProtocol(tt.packet)
|
||||
gotProto, gotOk := util.GetProtocol(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
@@ -1,4 +1,4 @@
|
||||
package olm
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -96,9 +98,9 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
|
||||
}
|
||||
|
||||
// Start starts the DNS proxy and registers with the filter
|
||||
func (p *DNSProxy) Start(filter *FilteredDevice) error {
|
||||
func (p *DNSProxy) Start(device *device.MiddleDevice) error {
|
||||
// Install packet filter rule
|
||||
filter.AddRule(p.proxyIP, p.handlePacket)
|
||||
device.AddRule(p.proxyIP, p.handlePacket)
|
||||
|
||||
// Start DNS listener
|
||||
p.wg.Add(2)
|
||||
@@ -110,9 +112,9 @@ func (p *DNSProxy) Start(filter *FilteredDevice) error {
|
||||
}
|
||||
|
||||
// Stop stops the DNS proxy
|
||||
func (p *DNSProxy) Stop(filter *FilteredDevice) {
|
||||
if filter != nil {
|
||||
filter.RemoveRule(p.proxyIP)
|
||||
func (p *DNSProxy) Stop(device *device.MiddleDevice) {
|
||||
if device != nil {
|
||||
device.RemoveRule(p.proxyIP)
|
||||
}
|
||||
p.cancel()
|
||||
p.wg.Wait()
|
||||
@@ -134,12 +136,12 @@ func (p *DNSProxy) handlePacket(packet []byte) bool {
|
||||
}
|
||||
|
||||
// Quick check for UDP port 53
|
||||
proto, ok := GetProtocol(packet)
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // 17 = UDP
|
||||
return false // Not UDP, don't handle
|
||||
}
|
||||
|
||||
port, ok := GetDestPort(packet)
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok || port != DNSPort {
|
||||
return false // Not DNS port
|
||||
}
|
||||
20
olm/olm.go
20
olm/olm.go
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
"github.com/fosrl/olm/api"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/dns"
|
||||
"github.com/fosrl/olm/network"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
@@ -70,8 +72,8 @@ var (
|
||||
holePunchData HolePunchData
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
filteredDev *FilteredDevice
|
||||
dnsProxy *DNSProxy
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
dnsProxy *dns.DNSProxy
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
tunnelCancel context.CancelFunc
|
||||
@@ -427,15 +429,15 @@ func StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
|
||||
// Wrap TUN device with packet filter for DNS proxy
|
||||
filteredDev = NewFilteredDevice(tdev)
|
||||
middleDev = middleDevice.NewMiddleDevice(tdev)
|
||||
|
||||
// Create and start DNS proxy
|
||||
dnsProxy, err = NewDNSProxy(tdev, config.MTU)
|
||||
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS proxy: %v", err)
|
||||
return
|
||||
}
|
||||
if err := dnsProxy.Start(filteredDev); err != nil {
|
||||
if err := dnsProxy.Start(middleDev); err != nil {
|
||||
logger.Error("Failed to start DNS proxy: %v", err)
|
||||
return
|
||||
}
|
||||
@@ -458,7 +460,7 @@ func StartTunnel(config TunnelConfig) {
|
||||
|
||||
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||
// Use filtered device instead of raw TUN device
|
||||
dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger))
|
||||
dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
|
||||
|
||||
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||
// if err != nil {
|
||||
@@ -1067,13 +1069,13 @@ func Close() {
|
||||
|
||||
// Stop DNS proxy
|
||||
if dnsProxy != nil {
|
||||
dnsProxy.Stop(filteredDev)
|
||||
dnsProxy.Stop(middleDev)
|
||||
dnsProxy = nil
|
||||
}
|
||||
|
||||
// Clear filtered device
|
||||
if filteredDev != nil {
|
||||
filteredDev = nil
|
||||
if middleDev != nil {
|
||||
middleDev = nil
|
||||
}
|
||||
|
||||
// Close TUN device
|
||||
|
||||
Reference in New Issue
Block a user