loser to workinr?

This commit is contained in:
Owen
2025-11-21 14:17:23 -05:00
parent f882cd983b
commit 04f7778765
17 changed files with 1170 additions and 982 deletions

237
olm/device_filter.go Normal file
View File

@@ -0,0 +1,237 @@
package olm
import (
"encoding/binary"
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketHandler processes intercepted packets and returns true if packet should be dropped
type PacketHandler func(packet []byte) bool
// FilterRule defines a rule for packet filtering
type FilterRule struct {
DestIP netip.Addr
Handler PacketHandler
}
// FilteredDevice wraps a TUN device with packet filtering capabilities
type FilteredDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
}
// NewFilteredDevice creates a new filtered TUN device wrapper
func NewFilteredDevice(device tun.Device) *FilteredDevice {
return &FilteredDevice{
Device: device,
rules: make([]FilterRule, 0),
}
}
// AddRule adds a packet filtering rule
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
})
}
// RemoveRule removes all rules for a given destination IP
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock()
defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
newRules = append(newRules, rule)
}
}
d.rules = newRules
}
// extractDestIP extracts destination IP from packet (fast path)
func extractDestIP(packet []byte) (netip.Addr, bool) {
if len(packet) < 20 {
return netip.Addr{}, false
}
version := packet[0] >> 4
switch version {
case 4:
if len(packet) < 20 {
return netip.Addr{}, false
}
// Destination IP is at bytes 16-19 for IPv4
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
return ip, true
case 6:
if len(packet) < 40 {
return netip.Addr{}, false
}
// Destination IP is at bytes 24-39 for IPv6
var ip16 [16]byte
copy(ip16[:], packet[24:40])
ip := netip.AddrFrom16(ip16)
return ip, true
}
return netip.Addr{}, false
}
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 {
return n, err
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, err
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
// Keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
}
return writeIdx, err
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return d.Device.Write(bufs, offset)
}
// Filter packets going down
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil // All packets were handled
}
return d.Device.Write(filteredBufs, offset)
}
// GetProtocol returns protocol number from IPv4 packet (fast path)
func GetProtocol(packet []byte) (uint8, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
if version == 4 {
return packet[9], true
} else if version == 6 {
if len(packet) < 40 {
return 0, false
}
return packet[6], true
}
return 0, false
}
// GetDestPort returns destination port from TCP/UDP packet (fast path)
func GetDestPort(packet []byte) (uint16, bool) {
if len(packet) < 20 {
return 0, false
}
version := packet[0] >> 4
var headerLen int
if version == 4 {
ihl := packet[0] & 0x0F
headerLen = int(ihl) * 4
if len(packet) < headerLen+4 {
return 0, false
}
} else if version == 6 {
headerLen = 40
if len(packet) < headerLen+4 {
return 0, false
}
} else {
return 0, false
}
// Destination port is at bytes 2-3 of TCP/UDP header
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
return port, true
}

100
olm/device_filter_test.go Normal file
View File

@@ -0,0 +1,100 @@
package olm
import (
"net/netip"
"testing"
)
func TestExtractDestIP(t *testing.T) {
tests := []struct {
name string
packet []byte
wantIP string
wantOk bool
}{
{
name: "IPv4 packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
},
wantIP: "10.30.30.30",
wantOk: true,
},
{
name: "Too short packet",
packet: []byte{0x45, 0x00},
wantIP: "",
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIP, gotOk := extractDestIP(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if tt.wantOk {
wantAddr := netip.MustParseAddr(tt.wantIP)
if gotIP != wantAddr {
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
}
}
})
}
}
func TestGetProtocol(t *testing.T) {
tests := []struct {
name string
packet []byte
wantProto uint8
wantOk bool
}{
{
name: "UDP packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
0x0a, 0x1e, 0x1e, 0x1e,
},
wantProto: 17,
wantOk: true,
},
{
name: "Too short",
packet: []byte{0x45, 0x00},
wantProto: 0,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotProto, gotOk := GetProtocol(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if gotProto != tt.wantProto {
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
}
})
}
}
func BenchmarkExtractDestIP(b *testing.B) {
packet := []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
extractDestIP(packet)
}
}

300
olm/dns_proxy.go Normal file
View File

@@ -0,0 +1,300 @@
package olm
import (
"context"
"fmt"
"net"
"net/netip"
"sync"
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
const (
// DNS proxy listening address
DNSProxyIP = "10.30.30.30"
DNSPort = 53
// Upstream DNS servers
UpstreamDNS1 = "8.8.8.8:53"
UpstreamDNS2 = "8.8.4.4:53"
)
// DNSProxy implements a DNS proxy using gvisor netstack
type DNSProxy struct {
stack *stack.Stack
ep *channel.Endpoint
proxyIP netip.Addr
mtu int
tunDevice tun.Device // Direct reference to underlying TUN device for responses
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
mutex sync.RWMutex
}
// NewDNSProxy creates a new DNS proxy
func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
proxyIP, err := netip.ParseAddr(DNSProxyIP)
if err != nil {
return nil, fmt.Errorf("invalid proxy IP: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
proxy := &DNSProxy{
proxyIP: proxyIP,
mtu: mtu,
tunDevice: tunDevice,
ctx: ctx,
cancel: cancel,
}
// Create gvisor netstack
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
proxy.ep = channel.New(256, uint32(mtu), "")
proxy.stack = stack.New(stackOpts)
// Create NIC
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
return nil, fmt.Errorf("failed to create NIC: %v", err)
}
// Add IP address
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}).WithPrefix(),
}
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %v", err)
}
// Add default route
proxy.stack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
return proxy, nil
}
// Start starts the DNS proxy and registers with the filter
func (p *DNSProxy) Start(filter *FilteredDevice) error {
// Install packet filter rule
filter.AddRule(p.proxyIP, p.handlePacket)
// Start DNS listener
p.wg.Add(2)
go p.runDNSListener()
go p.runPacketSender()
logger.Info("DNS proxy started on %s:%d", DNSProxyIP, DNSPort)
return nil
}
// Stop stops the DNS proxy
func (p *DNSProxy) Stop(filter *FilteredDevice) {
if filter != nil {
filter.RemoveRule(p.proxyIP)
}
p.cancel()
p.wg.Wait()
if p.stack != nil {
p.stack.Close()
}
if p.ep != nil {
p.ep.Close()
}
logger.Info("DNS proxy stopped")
}
// handlePacket is called by the filter for packets destined to DNS proxy IP
func (p *DNSProxy) handlePacket(packet []byte) bool {
if len(packet) < 20 {
return false // Don't drop, malformed
}
// Quick check for UDP port 53
proto, ok := GetProtocol(packet)
if !ok || proto != 17 { // 17 = UDP
return false // Not UDP, don't handle
}
port, ok := GetDestPort(packet)
if !ok || port != DNSPort {
return false // Not DNS port
}
// Inject packet into our netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Drop packet from normal path
}
// runDNSListener listens for DNS queries on the netstack
func (p *DNSProxy) runDNSListener() {
defer p.wg.Done()
// Create UDP listener using gonet
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}),
Port: DNSPort,
}
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
if err != nil {
logger.Error("Failed to create DNS listener: %v", err)
return
}
defer udpConn.Close()
logger.Debug("DNS proxy listening on netstack")
// Handle DNS queries
buf := make([]byte, 4096)
for {
select {
case <-p.ctx.Done():
return
default:
}
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
n, remoteAddr, err := udpConn.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
}
if p.ctx.Err() != nil {
return
}
logger.Error("DNS read error: %v", err)
continue
}
query := make([]byte, n)
copy(query, buf[:n])
// Handle query in background
go p.forwardDNSQuery(udpConn, query, remoteAddr)
}
}
// forwardDNSQuery forwards a DNS query to upstream DNS server
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
// Try primary DNS server
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
if err != nil {
// Try secondary DNS server
logger.Debug("Primary DNS failed, trying secondary: %v", err)
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
if err != nil {
logger.Error("Both DNS servers failed: %v", err)
return
}
}
// Send response back to client through netstack
_, err = udpConn.WriteTo(response, clientAddr)
if err != nil {
logger.Error("Failed to send DNS response: %v", err)
}
}
// queryUpstream sends a DNS query to upstream server
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
conn, err := net.DialTimeout("udp", server, timeout)
if err != nil {
return nil, err
}
defer conn.Close()
conn.SetDeadline(time.Now().Add(timeout))
if _, err := conn.Write(query); err != nil {
return nil, err
}
response := make([]byte, 4096)
n, err := conn.Read(response)
if err != nil {
return nil, err
}
return response[:n], nil
}
// runPacketSender sends packets from netstack back to TUN
func (p *DNSProxy) runPacketSender() {
defer p.wg.Done()
for {
select {
case <-p.ctx.Done():
return
default:
}
// Read packets from netstack endpoint
pkt := p.ep.Read()
if pkt == nil {
// No packet available, small sleep to avoid busy loop
time.Sleep(1 * time.Millisecond)
continue
}
// Convert packet to bytes
view := pkt.ToView()
packetData := view.AsSlice()
// Make a copy and write directly back to the TUN device
// This bypasses WireGuard - the packet goes straight back to the host
buf := make([]byte, len(packetData))
copy(buf, packetData)
// Write packet back to TUN device
bufs := [][]byte{buf}
_, err := p.tunDevice.Write(bufs, 0)
if err != nil {
logger.Error("Failed to write DNS response to TUN: %v", err)
}
pkt.DecRef()
}
}

View File

@@ -0,0 +1,111 @@
package olm
// This file demonstrates how to add additional virtual services using the FilteredDevice infrastructure
// Copy and modify this template to add new services
import (
"context"
"net/netip"
"sync"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun"
)
// Example: Simple echo server on 10.30.30.50:7777
const (
EchoProxyIP = "10.30.30.50"
EchoProxyPort = 7777
)
// EchoProxy implements a simple echo server
type EchoProxy struct {
proxyIP netip.Addr
tunDevice tun.Device
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
}
// NewEchoProxy creates a new echo proxy instance
func NewEchoProxy(tunDevice tun.Device) (*EchoProxy, error) {
proxyIP := netip.MustParseAddr(EchoProxyIP)
ctx, cancel := context.WithCancel(context.Background())
return &EchoProxy{
proxyIP: proxyIP,
tunDevice: tunDevice,
ctx: ctx,
cancel: cancel,
}, nil
}
// Start registers the proxy with the filter
func (e *EchoProxy) Start(filter *FilteredDevice) error {
filter.AddRule(e.proxyIP, e.handlePacket)
logger.Info("Echo proxy started on %s:%d", EchoProxyIP, EchoProxyPort)
return nil
}
// Stop unregisters the proxy
func (e *EchoProxy) Stop(filter *FilteredDevice) {
if filter != nil {
filter.RemoveRule(e.proxyIP)
}
e.cancel()
e.wg.Wait()
logger.Info("Echo proxy stopped")
}
// handlePacket processes packets destined for the echo server
func (e *EchoProxy) handlePacket(packet []byte) bool {
// Quick validation
if len(packet) < 20 {
return false
}
// Check protocol (UDP)
proto, ok := GetProtocol(packet)
if !ok || proto != 17 {
return false
}
// Check port
port, ok := GetDestPort(packet)
if !ok || port != EchoProxyPort {
return false
}
// For a real implementation, you would:
// 1. Parse the UDP packet
// 2. Extract the payload
// 3. Create a response packet with swapped src/dest
// 4. Write response back to TUN device
logger.Debug("Echo proxy received packet (would echo back)")
// Return true to drop packet from normal WireGuard path
return true
}
// Example integration in olm.go:
//
// var echoProxy *EchoProxy
//
// // During tunnel setup (after creating filteredDev):
// echoProxy, err = NewEchoProxy(tdev)
// if err != nil {
// logger.Error("Failed to create echo proxy: %v", err)
// return
// }
// if err := echoProxy.Start(filteredDev); err != nil {
// logger.Error("Failed to start echo proxy: %v", err)
// return
// }
//
// // During tunnel teardown:
// if echoProxy != nil {
// echoProxy.Stop(filteredDev)
// echoProxy = nil
// }

View File

@@ -15,7 +15,6 @@ import (
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/network"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/tunfilter"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -71,6 +70,8 @@ var (
holePunchData HolePunchData
uapiListener net.Listener
tdev tun.Device
filteredDev *FilteredDevice
dnsProxy *DNSProxy
apiServer *api.API
olmClient *websocket.Client
tunnelCancel context.CancelFunc
@@ -82,12 +83,6 @@ var (
globalCtx context.Context
stopRegister func()
stopPing chan struct{}
// Packet interceptor components
filteredDev *tunfilter.FilteredDevice
packetInjector *tunfilter.PacketInjector
interceptorManager *tunfilter.InterceptorManager
ipFilter *tunfilter.IPFilter
)
func Init(ctx context.Context, config GlobalConfig) {
@@ -431,15 +426,19 @@ func StartTunnel(config TunnelConfig) {
}
}
// Create packet injector for the TUN device
packetInjector = tunfilter.NewPacketInjector(tdev)
// Wrap TUN device with packet filter for DNS proxy
filteredDev = NewFilteredDevice(tdev)
// Create interceptor manager
interceptorManager = tunfilter.NewInterceptorManager(packetInjector)
// Create an interceptor filter and wrap the TUN device
interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager)
filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter)
// Create and start DNS proxy
dnsProxy, err = NewDNSProxy(tdev, config.MTU)
if err != nil {
logger.Error("Failed to create DNS proxy: %v", err)
return
}
if err := dnsProxy.Start(filteredDev); err != nil {
logger.Error("Failed to start DNS proxy: %v", err)
return
}
// fileUAPI, err := func() (*os.File, error) {
// if config.FileDescriptorUAPI != 0 {
@@ -1066,26 +1065,17 @@ func Close() {
dev = nil
}
// Stop packet injector
if packetInjector != nil {
packetInjector.Stop()
packetInjector = nil
// Stop DNS proxy
if dnsProxy != nil {
dnsProxy.Stop(filteredDev)
dnsProxy = nil
}
// Stop interceptor manager
if interceptorManager != nil {
interceptorManager.Stop()
interceptorManager = nil
}
// Clear packet filter
// Clear filtered device
if filteredDev != nil {
filteredDev.SetFilter(nil)
filteredDev = nil
}
ipFilter = nil
// Close TUN device
if tdev != nil {
tdev.Close()