mirror of
https://github.com/fosrl/olm.git
synced 2026-02-21 20:36:42 +00:00
loser to workinr?
This commit is contained in:
237
olm/device_filter.go
Normal file
237
olm/device_filter.go
Normal file
@@ -0,0 +1,237 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// PacketHandler processes intercepted packets and returns true if packet should be dropped
|
||||
type PacketHandler func(packet []byte) bool
|
||||
|
||||
// FilterRule defines a rule for packet filtering
|
||||
type FilterRule struct {
|
||||
DestIP netip.Addr
|
||||
Handler PacketHandler
|
||||
}
|
||||
|
||||
// FilteredDevice wraps a TUN device with packet filtering capabilities
|
||||
type FilteredDevice struct {
|
||||
tun.Device
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewFilteredDevice creates a new filtered TUN device wrapper
|
||||
func NewFilteredDevice(device tun.Device) *FilteredDevice {
|
||||
return &FilteredDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// AddRule adds a packet filtering rule
|
||||
func (d *FilteredDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
d.rules = append(d.rules, FilterRule{
|
||||
DestIP: destIP,
|
||||
Handler: handler,
|
||||
})
|
||||
}
|
||||
|
||||
// RemoveRule removes all rules for a given destination IP
|
||||
func (d *FilteredDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.mutex.Lock()
|
||||
defer d.mutex.Unlock()
|
||||
newRules := make([]FilterRule, 0, len(d.rules))
|
||||
for _, rule := range d.rules {
|
||||
if rule.DestIP != destIP {
|
||||
newRules = append(newRules, rule)
|
||||
}
|
||||
}
|
||||
d.rules = newRules
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
if len(packet) < 20 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 16-19 for IPv4
|
||||
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||
return ip, true
|
||||
case 6:
|
||||
if len(packet) < 40 {
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
// Destination IP is at bytes 24-39 for IPv6
|
||||
var ip16 [16]byte
|
||||
copy(ip16[:], packet[24:40])
|
||||
ip := netip.AddrFrom16(ip16)
|
||||
return ip, true
|
||||
}
|
||||
|
||||
return netip.Addr{}, false
|
||||
}
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
n, err = d.Device.Read(bufs, sizes, offset)
|
||||
if err != nil || n == 0 {
|
||||
return n, err
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
writeIdx := 0
|
||||
for readIdx := 0; readIdx < n; readIdx++ {
|
||||
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
|
||||
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
// Keep packet
|
||||
if writeIdx != readIdx {
|
||||
bufs[writeIdx] = bufs[readIdx]
|
||||
sizes[writeIdx] = sizes[readIdx]
|
||||
}
|
||||
writeIdx++
|
||||
}
|
||||
}
|
||||
|
||||
return writeIdx, err
|
||||
}
|
||||
|
||||
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
|
||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||
d.mutex.RLock()
|
||||
rules := d.rules
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return d.Device.Write(bufs, offset)
|
||||
}
|
||||
|
||||
// Filter packets going down
|
||||
filteredBufs := make([][]byte, 0, len(bufs))
|
||||
for _, buf := range bufs {
|
||||
if len(buf) <= offset {
|
||||
continue
|
||||
}
|
||||
|
||||
packet := buf[offset:]
|
||||
destIP, ok := extractDestIP(packet)
|
||||
if !ok {
|
||||
// Can't parse, keep packet
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if packet matches any rule
|
||||
handled := false
|
||||
for _, rule := range rules {
|
||||
if rule.DestIP == destIP {
|
||||
if rule.Handler(packet) {
|
||||
// Packet was handled and should be dropped
|
||||
handled = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !handled {
|
||||
filteredBufs = append(filteredBufs, buf)
|
||||
}
|
||||
}
|
||||
|
||||
if len(filteredBufs) == 0 {
|
||||
return len(bufs), nil // All packets were handled
|
||||
}
|
||||
|
||||
return d.Device.Write(filteredBufs, offset)
|
||||
}
|
||||
|
||||
// GetProtocol returns protocol number from IPv4 packet (fast path)
|
||||
func GetProtocol(packet []byte) (uint8, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
version := packet[0] >> 4
|
||||
if version == 4 {
|
||||
return packet[9], true
|
||||
} else if version == 6 {
|
||||
if len(packet) < 40 {
|
||||
return 0, false
|
||||
}
|
||||
return packet[6], true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// GetDestPort returns destination port from TCP/UDP packet (fast path)
|
||||
func GetDestPort(packet []byte) (uint16, bool) {
|
||||
if len(packet) < 20 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
version := packet[0] >> 4
|
||||
var headerLen int
|
||||
|
||||
if version == 4 {
|
||||
ihl := packet[0] & 0x0F
|
||||
headerLen = int(ihl) * 4
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else if version == 6 {
|
||||
headerLen = 40
|
||||
if len(packet) < headerLen+4 {
|
||||
return 0, false
|
||||
}
|
||||
} else {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// Destination port is at bytes 2-3 of TCP/UDP header
|
||||
port := binary.BigEndian.Uint16(packet[headerLen+2 : headerLen+4])
|
||||
return port, true
|
||||
}
|
||||
100
olm/device_filter_test.go
Normal file
100
olm/device_filter_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractDestIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantIP string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "IPv4 packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
|
||||
},
|
||||
wantIP: "10.30.30.30",
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short packet",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantIP: "",
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotIP, gotOk := extractDestIP(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if tt.wantOk {
|
||||
wantAddr := netip.MustParseAddr(tt.wantIP)
|
||||
if gotIP != wantAddr {
|
||||
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetProtocol(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
packet []byte
|
||||
wantProto uint8
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "UDP packet",
|
||||
packet: []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
},
|
||||
wantProto: 17,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "Too short",
|
||||
packet: []byte{0x45, 0x00},
|
||||
wantProto: 0,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotProto, gotOk := GetProtocol(tt.packet)
|
||||
if gotOk != tt.wantOk {
|
||||
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
|
||||
return
|
||||
}
|
||||
if gotProto != tt.wantProto {
|
||||
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkExtractDestIP(b *testing.B) {
|
||||
packet := []byte{
|
||||
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
|
||||
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
|
||||
0x0a, 0x1e, 0x1e, 0x1e,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
extractDestIP(packet)
|
||||
}
|
||||
}
|
||||
300
olm/dns_proxy.go
Normal file
300
olm/dns_proxy.go
Normal file
@@ -0,0 +1,300 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
const (
|
||||
// DNS proxy listening address
|
||||
DNSProxyIP = "10.30.30.30"
|
||||
DNSPort = 53
|
||||
|
||||
// Upstream DNS servers
|
||||
UpstreamDNS1 = "8.8.8.8:53"
|
||||
UpstreamDNS2 = "8.8.4.4:53"
|
||||
)
|
||||
|
||||
// DNSProxy implements a DNS proxy using gvisor netstack
|
||||
type DNSProxy struct {
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
proxyIP netip.Addr
|
||||
mtu int
|
||||
tunDevice tun.Device // Direct reference to underlying TUN device for responses
|
||||
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
// NewDNSProxy creates a new DNS proxy
|
||||
func NewDNSProxy(tunDevice tun.Device, mtu int) (*DNSProxy, error) {
|
||||
proxyIP, err := netip.ParseAddr(DNSProxyIP)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid proxy IP: %w", err)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
proxy := &DNSProxy{
|
||||
proxyIP: proxyIP,
|
||||
mtu: mtu,
|
||||
tunDevice: tunDevice,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
proxy.ep = channel.New(256, uint32(mtu), "")
|
||||
proxy.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := proxy.stack.CreateNIC(1, proxy.ep); err != nil {
|
||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := proxy.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return nil, fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
proxy.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
return proxy, nil
|
||||
}
|
||||
|
||||
// Start starts the DNS proxy and registers with the filter
|
||||
func (p *DNSProxy) Start(filter *FilteredDevice) error {
|
||||
// Install packet filter rule
|
||||
filter.AddRule(p.proxyIP, p.handlePacket)
|
||||
|
||||
// Start DNS listener
|
||||
p.wg.Add(2)
|
||||
go p.runDNSListener()
|
||||
go p.runPacketSender()
|
||||
|
||||
logger.Info("DNS proxy started on %s:%d", DNSProxyIP, DNSPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the DNS proxy
|
||||
func (p *DNSProxy) Stop(filter *FilteredDevice) {
|
||||
if filter != nil {
|
||||
filter.RemoveRule(p.proxyIP)
|
||||
}
|
||||
p.cancel()
|
||||
p.wg.Wait()
|
||||
|
||||
if p.stack != nil {
|
||||
p.stack.Close()
|
||||
}
|
||||
if p.ep != nil {
|
||||
p.ep.Close()
|
||||
}
|
||||
|
||||
logger.Info("DNS proxy stopped")
|
||||
}
|
||||
|
||||
// handlePacket is called by the filter for packets destined to DNS proxy IP
|
||||
func (p *DNSProxy) handlePacket(packet []byte) bool {
|
||||
if len(packet) < 20 {
|
||||
return false // Don't drop, malformed
|
||||
}
|
||||
|
||||
// Quick check for UDP port 53
|
||||
proto, ok := GetProtocol(packet)
|
||||
if !ok || proto != 17 { // 17 = UDP
|
||||
return false // Not UDP, don't handle
|
||||
}
|
||||
|
||||
port, ok := GetDestPort(packet)
|
||||
if !ok || port != DNSPort {
|
||||
return false // Not DNS port
|
||||
}
|
||||
|
||||
// Inject packet into our netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
p.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
p.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Drop packet from normal path
|
||||
}
|
||||
|
||||
// runDNSListener listens for DNS queries on the netstack
|
||||
func (p *DNSProxy) runDNSListener() {
|
||||
defer p.wg.Done()
|
||||
|
||||
// Create UDP listener using gonet
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte{10, 30, 30, 30}),
|
||||
Port: DNSPort,
|
||||
}
|
||||
|
||||
udpConn, err := gonet.DialUDP(p.stack, laddr, nil, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS listener: %v", err)
|
||||
return
|
||||
}
|
||||
defer udpConn.Close()
|
||||
|
||||
logger.Debug("DNS proxy listening on netstack")
|
||||
|
||||
// Handle DNS queries
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
udpConn.SetReadDeadline(time.Now().Add(1 * time.Second))
|
||||
n, remoteAddr, err := udpConn.ReadFrom(buf)
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
continue
|
||||
}
|
||||
if p.ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
logger.Error("DNS read error: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
query := make([]byte, n)
|
||||
copy(query, buf[:n])
|
||||
|
||||
// Handle query in background
|
||||
go p.forwardDNSQuery(udpConn, query, remoteAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// forwardDNSQuery forwards a DNS query to upstream DNS server
|
||||
func (p *DNSProxy) forwardDNSQuery(udpConn *gonet.UDPConn, query []byte, clientAddr net.Addr) {
|
||||
// Try primary DNS server
|
||||
response, err := p.queryUpstream(UpstreamDNS1, query, 2*time.Second)
|
||||
if err != nil {
|
||||
// Try secondary DNS server
|
||||
logger.Debug("Primary DNS failed, trying secondary: %v", err)
|
||||
response, err = p.queryUpstream(UpstreamDNS2, query, 2*time.Second)
|
||||
if err != nil {
|
||||
logger.Error("Both DNS servers failed: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Send response back to client through netstack
|
||||
_, err = udpConn.WriteTo(response, clientAddr)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send DNS response: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// queryUpstream sends a DNS query to upstream server
|
||||
func (p *DNSProxy) queryUpstream(server string, query []byte, timeout time.Duration) ([]byte, error) {
|
||||
conn, err := net.DialTimeout("udp", server, timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
conn.SetDeadline(time.Now().Add(timeout))
|
||||
|
||||
if _, err := conn.Write(query); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := make([]byte, 4096)
|
||||
n, err := conn.Read(response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response[:n], nil
|
||||
}
|
||||
|
||||
// runPacketSender sends packets from netstack back to TUN
|
||||
func (p *DNSProxy) runPacketSender() {
|
||||
defer p.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-p.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Read packets from netstack endpoint
|
||||
pkt := p.ep.Read()
|
||||
if pkt == nil {
|
||||
// No packet available, small sleep to avoid busy loop
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Convert packet to bytes
|
||||
view := pkt.ToView()
|
||||
packetData := view.AsSlice()
|
||||
|
||||
// Make a copy and write directly back to the TUN device
|
||||
// This bypasses WireGuard - the packet goes straight back to the host
|
||||
buf := make([]byte, len(packetData))
|
||||
copy(buf, packetData)
|
||||
|
||||
// Write packet back to TUN device
|
||||
bufs := [][]byte{buf}
|
||||
_, err := p.tunDevice.Write(bufs, 0)
|
||||
if err != nil {
|
||||
logger.Error("Failed to write DNS response to TUN: %v", err)
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
111
olm/example_extension.go.template
Normal file
111
olm/example_extension.go.template
Normal file
@@ -0,0 +1,111 @@
|
||||
package olm
|
||||
|
||||
// This file demonstrates how to add additional virtual services using the FilteredDevice infrastructure
|
||||
// Copy and modify this template to add new services
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"sync"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
// Example: Simple echo server on 10.30.30.50:7777
|
||||
|
||||
const (
|
||||
EchoProxyIP = "10.30.30.50"
|
||||
EchoProxyPort = 7777
|
||||
)
|
||||
|
||||
// EchoProxy implements a simple echo server
|
||||
type EchoProxy struct {
|
||||
proxyIP netip.Addr
|
||||
tunDevice tun.Device
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewEchoProxy creates a new echo proxy instance
|
||||
func NewEchoProxy(tunDevice tun.Device) (*EchoProxy, error) {
|
||||
proxyIP := netip.MustParseAddr(EchoProxyIP)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
return &EchoProxy{
|
||||
proxyIP: proxyIP,
|
||||
tunDevice: tunDevice,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start registers the proxy with the filter
|
||||
func (e *EchoProxy) Start(filter *FilteredDevice) error {
|
||||
filter.AddRule(e.proxyIP, e.handlePacket)
|
||||
logger.Info("Echo proxy started on %s:%d", EchoProxyIP, EchoProxyPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop unregisters the proxy
|
||||
func (e *EchoProxy) Stop(filter *FilteredDevice) {
|
||||
if filter != nil {
|
||||
filter.RemoveRule(e.proxyIP)
|
||||
}
|
||||
e.cancel()
|
||||
e.wg.Wait()
|
||||
logger.Info("Echo proxy stopped")
|
||||
}
|
||||
|
||||
// handlePacket processes packets destined for the echo server
|
||||
func (e *EchoProxy) handlePacket(packet []byte) bool {
|
||||
// Quick validation
|
||||
if len(packet) < 20 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check protocol (UDP)
|
||||
proto, ok := GetProtocol(packet)
|
||||
if !ok || proto != 17 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check port
|
||||
port, ok := GetDestPort(packet)
|
||||
if !ok || port != EchoProxyPort {
|
||||
return false
|
||||
}
|
||||
|
||||
// For a real implementation, you would:
|
||||
// 1. Parse the UDP packet
|
||||
// 2. Extract the payload
|
||||
// 3. Create a response packet with swapped src/dest
|
||||
// 4. Write response back to TUN device
|
||||
|
||||
logger.Debug("Echo proxy received packet (would echo back)")
|
||||
|
||||
// Return true to drop packet from normal WireGuard path
|
||||
return true
|
||||
}
|
||||
|
||||
// Example integration in olm.go:
|
||||
//
|
||||
// var echoProxy *EchoProxy
|
||||
//
|
||||
// // During tunnel setup (after creating filteredDev):
|
||||
// echoProxy, err = NewEchoProxy(tdev)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to create echo proxy: %v", err)
|
||||
// return
|
||||
// }
|
||||
// if err := echoProxy.Start(filteredDev); err != nil {
|
||||
// logger.Error("Failed to start echo proxy: %v", err)
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// // During tunnel teardown:
|
||||
// if echoProxy != nil {
|
||||
// echoProxy.Stop(filteredDev)
|
||||
// echoProxy = nil
|
||||
// }
|
||||
48
olm/olm.go
48
olm/olm.go
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/fosrl/olm/api"
|
||||
"github.com/fosrl/olm/network"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/tunfilter"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
@@ -71,6 +70,8 @@ var (
|
||||
holePunchData HolePunchData
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
filteredDev *FilteredDevice
|
||||
dnsProxy *DNSProxy
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
tunnelCancel context.CancelFunc
|
||||
@@ -82,12 +83,6 @@ var (
|
||||
globalCtx context.Context
|
||||
stopRegister func()
|
||||
stopPing chan struct{}
|
||||
|
||||
// Packet interceptor components
|
||||
filteredDev *tunfilter.FilteredDevice
|
||||
packetInjector *tunfilter.PacketInjector
|
||||
interceptorManager *tunfilter.InterceptorManager
|
||||
ipFilter *tunfilter.IPFilter
|
||||
)
|
||||
|
||||
func Init(ctx context.Context, config GlobalConfig) {
|
||||
@@ -431,15 +426,19 @@ func StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
// Create packet injector for the TUN device
|
||||
packetInjector = tunfilter.NewPacketInjector(tdev)
|
||||
// Wrap TUN device with packet filter for DNS proxy
|
||||
filteredDev = NewFilteredDevice(tdev)
|
||||
|
||||
// Create interceptor manager
|
||||
interceptorManager = tunfilter.NewInterceptorManager(packetInjector)
|
||||
|
||||
// Create an interceptor filter and wrap the TUN device
|
||||
interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager)
|
||||
filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter)
|
||||
// Create and start DNS proxy
|
||||
dnsProxy, err = NewDNSProxy(tdev, config.MTU)
|
||||
if err != nil {
|
||||
logger.Error("Failed to create DNS proxy: %v", err)
|
||||
return
|
||||
}
|
||||
if err := dnsProxy.Start(filteredDev); err != nil {
|
||||
logger.Error("Failed to start DNS proxy: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// fileUAPI, err := func() (*os.File, error) {
|
||||
// if config.FileDescriptorUAPI != 0 {
|
||||
@@ -1066,26 +1065,17 @@ func Close() {
|
||||
dev = nil
|
||||
}
|
||||
|
||||
// Stop packet injector
|
||||
if packetInjector != nil {
|
||||
packetInjector.Stop()
|
||||
packetInjector = nil
|
||||
// Stop DNS proxy
|
||||
if dnsProxy != nil {
|
||||
dnsProxy.Stop(filteredDev)
|
||||
dnsProxy = nil
|
||||
}
|
||||
|
||||
// Stop interceptor manager
|
||||
if interceptorManager != nil {
|
||||
interceptorManager.Stop()
|
||||
interceptorManager = nil
|
||||
}
|
||||
|
||||
// Clear packet filter
|
||||
// Clear filtered device
|
||||
if filteredDev != nil {
|
||||
filteredDev.SetFilter(nil)
|
||||
filteredDev = nil
|
||||
}
|
||||
|
||||
ipFilter = nil
|
||||
|
||||
// Close TUN device
|
||||
if tdev != nil {
|
||||
tdev.Close()
|
||||
|
||||
Reference in New Issue
Block a user