mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
@@ -1,7 +1,6 @@
|
|||||||
package olm
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -17,23 +16,23 @@ type FilterRule struct {
|
|||||||
Handler PacketHandler
|
Handler PacketHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice wraps a TUN device with packet filtering capabilities
|
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||||
type FilteredDevice struct {
|
type MiddleDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
rules []FilterRule
|
rules []FilterRule
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFilteredDevice creates a new filtered TUN device wrapper
|
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||||
func NewFilteredDevice(device tun.Device) *FilteredDevice {
|
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||||
return &FilteredDevice{
|
return &MiddleDevice{
|
||||||
Device: device,
|
Device: device,
|
||||||
rules: make([]FilterRule, 0),
|
rules: make([]FilterRule, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddRule adds a packet filtering rule
|
// AddRule adds a packet filtering rule
|
||||||
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
d.rules = append(d.rules, FilterRule{
|
d.rules = append(d.rules, FilterRule{
|
||||||
@@ -43,7 +42,7 @@ func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveRule removes all rules for a given destination IP
|
// RemoveRule removes all rules for a given destination IP
|
||||||
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) {
|
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
newRules := make([]FilterRule, 0, len(d.rules))
|
newRules := make([]FilterRule, 0, len(d.rules))
|
||||||
@@ -86,7 +85,7 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
n, err = d.Device.Read(bufs, sizes, offset)
|
n, err = d.Device.Read(bufs, sizes, offset)
|
||||||
if err != nil || n == 0 {
|
if err != nil || n == 0 {
|
||||||
return n, err
|
return n, err
|
||||||
@@ -142,7 +141,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
d.mutex.RLock()
|
d.mutex.RLock()
|
||||||
rules := d.rules
|
rules := d.rules
|
||||||
d.mutex.RUnlock()
|
d.mutex.RUnlock()
|
||||||
@@ -189,49 +188,3 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
|
|
||||||
return d.Device.Write(filteredBufs, offset)
|
return d.Device.Write(filteredBufs, offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetProtocol returns protocol number from IPv4 packet (fast path)
|
|
||||||
func GetProtocol(packet []byte) (uint8, bool) {
|
|
||||||
if len(packet) < 20 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
version := packet[0] >> 4
|
|
||||||
if version == 4 {
|
|
||||||
return packet[9], true
|
|
||||||
} else if version == 6 {
|
|
||||||
if len(packet) < 40 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
return packet[6], true
|
|
||||||
}
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetDestPort returns destination port from TCP/UDP packet (fast path)
|
|
||||||
func GetDestPort(packet []byte) (uint16, bool) {
|
|
||||||
if len(packet) < 20 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
version := packet[0] >> 4
|
|
||||||
var headerLen int
|
|
||||||
|
|
||||||
if version == 4 {
|
|
||||||
ihl := packet[0] & 0x0F
|
|
||||||
headerLen = int(ihl) * 4
|
|
||||||
if len(packet) < headerLen+4 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
} else if version == 6 {
|
|
||||||
headerLen = 40
|
|
||||||
if len(packet) < headerLen+4 {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return 0, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Destination port is at bytes 2-3 of TCP/UDP header
|
|
||||||
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
|
|
||||||
return port, true
|
|
||||||
}
|
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
package olm
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestExtractDestIP(t *testing.T) {
|
func TestExtractDestIP(t *testing.T) {
|
||||||
@@ -74,7 +76,7 @@ func TestGetProtocol(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
gotProto, gotOk := GetProtocol(tt.packet)
|
gotProto, gotOk := util.GetProtocol(tt.packet)
|
||||||
if gotOk != tt.wantOk {
|
if gotOk != tt.wantOk {
|
||||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||||
return
|
return
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package olm
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
|
"github.com/fosrl/olm/device"
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
@@ -96,9 +98,9 @@ func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the DNS proxy and registers with the filter
|
// Start starts the DNS proxy and registers with the filter
|
||||||
func (p *DNSProxy) Start(filter *FilteredDevice) error {
|
func (p *DNSProxy) Start(device *device.MiddleDevice) error {
|
||||||
// Install packet filter rule
|
// Install packet filter rule
|
||||||
filter.AddRule(p.proxyIP, p.handlePacket)
|
device.AddRule(p.proxyIP, p.handlePacket)
|
||||||
|
|
||||||
// Start DNS listener
|
// Start DNS listener
|
||||||
p.wg.Add(2)
|
p.wg.Add(2)
|
||||||
@@ -110,9 +112,9 @@ func (p *DNSProxy) Start(filter *FilteredDevice) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Stop stops the DNS proxy
|
// Stop stops the DNS proxy
|
||||||
func (p *DNSProxy) Stop(filter *FilteredDevice) {
|
func (p *DNSProxy) Stop(device *device.MiddleDevice) {
|
||||||
if filter != nil {
|
if device != nil {
|
||||||
filter.RemoveRule(p.proxyIP)
|
device.RemoveRule(p.proxyIP)
|
||||||
}
|
}
|
||||||
p.cancel()
|
p.cancel()
|
||||||
p.wg.Wait()
|
p.wg.Wait()
|
||||||
@@ -134,12 +136,12 @@ func (p *DNSProxy) handlePacket(packet []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Quick check for UDP port 53
|
// Quick check for UDP port 53
|
||||||
proto, ok := GetProtocol(packet)
|
proto, ok := util.GetProtocol(packet)
|
||||||
if !ok || proto != 17 { // 17 = UDP
|
if !ok || proto != 17 { // 17 = UDP
|
||||||
return false // Not UDP, don't handle
|
return false // Not UDP, don't handle
|
||||||
}
|
}
|
||||||
|
|
||||||
port, ok := GetDestPort(packet)
|
port, ok := util.GetDestPort(packet)
|
||||||
if !ok || port != DNSPort {
|
if !ok || port != DNSPort {
|
||||||
return false // Not DNS port
|
return false // Not DNS port
|
||||||
}
|
}
|
||||||
20
olm/olm.go
20
olm/olm.go
@@ -13,6 +13,8 @@ import (
|
|||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/olm/api"
|
"github.com/fosrl/olm/api"
|
||||||
|
middleDevice "github.com/fosrl/olm/device"
|
||||||
|
"github.com/fosrl/olm/dns"
|
||||||
"github.com/fosrl/olm/network"
|
"github.com/fosrl/olm/network"
|
||||||
"github.com/fosrl/olm/peermonitor"
|
"github.com/fosrl/olm/peermonitor"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
@@ -70,8 +72,8 @@ var (
|
|||||||
holePunchData HolePunchData
|
holePunchData HolePunchData
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
tdev tun.Device
|
tdev tun.Device
|
||||||
filteredDev *FilteredDevice
|
middleDev *middleDevice.MiddleDevice
|
||||||
dnsProxy *DNSProxy
|
dnsProxy *dns.DNSProxy
|
||||||
apiServer *api.API
|
apiServer *api.API
|
||||||
olmClient *websocket.Client
|
olmClient *websocket.Client
|
||||||
tunnelCancel context.CancelFunc
|
tunnelCancel context.CancelFunc
|
||||||
@@ -427,15 +429,15 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Wrap TUN device with packet filter for DNS proxy
|
// Wrap TUN device with packet filter for DNS proxy
|
||||||
filteredDev = NewFilteredDevice(tdev)
|
middleDev = middleDevice.NewMiddleDevice(tdev)
|
||||||
|
|
||||||
// Create and start DNS proxy
|
// Create and start DNS proxy
|
||||||
dnsProxy, err = NewDNSProxy(tdev, config.MTU)
|
dnsProxy, err = dns.NewDNSProxy(tdev, config.MTU)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to create DNS proxy: %v", err)
|
logger.Error("Failed to create DNS proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := dnsProxy.Start(filteredDev); err != nil {
|
if err := dnsProxy.Start(middleDev); err != nil {
|
||||||
logger.Error("Failed to start DNS proxy: %v", err)
|
logger.Error("Failed to start DNS proxy: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -458,7 +460,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
|
||||||
// Use filtered device instead of raw TUN device
|
// Use filtered device instead of raw TUN device
|
||||||
dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger))
|
dev = device.NewDevice(middleDev, sharedBind, (*device.Logger)(wgLogger))
|
||||||
|
|
||||||
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
|
||||||
// if err != nil {
|
// if err != nil {
|
||||||
@@ -1067,13 +1069,13 @@ func Close() {
|
|||||||
|
|
||||||
// Stop DNS proxy
|
// Stop DNS proxy
|
||||||
if dnsProxy != nil {
|
if dnsProxy != nil {
|
||||||
dnsProxy.Stop(filteredDev)
|
dnsProxy.Stop(middleDev)
|
||||||
dnsProxy = nil
|
dnsProxy = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear filtered device
|
// Clear filtered device
|
||||||
if filteredDev != nil {
|
if middleDev != nil {
|
||||||
filteredDev = nil
|
middleDev = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close TUN device
|
// Close TUN device
|
||||||
|
|||||||
Reference in New Issue
Block a user