diff --git a/go.mod b/go.mod index 0c16b81..890f439 100644 --- a/go.mod +++ b/go.mod @@ -7,19 +7,21 @@ require ( github.com/fosrl/newt v0.0.0 github.com/gorilla/websocket v1.5.3 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.44.0 - golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 golang.org/x/sys v0.38.0 golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 + gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e software.sslmate.com/src/go-pkcs12 v0.6.0 ) require ( + github.com/google/btree v1.1.3 // indirect github.com/vishvananda/netns v0.0.5 // indirect + golang.org/x/crypto v0.44.0 // indirect + golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect golang.org/x/net v0.47.0 // indirect + golang.org/x/time v0.12.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect ) replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index d2dbb17..3045aa6 100644 --- a/go.sum +++ b/go.sum @@ -14,6 +14,8 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU= golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0= golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -28,7 +30,7 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= -gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c= -gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs= +gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e h1:upyNwibTehzZl2FY2LEQ6bTRKOrU0IMiBLiIKT+dKF0= +gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e/go.mod h1:W1ZgZ/Dh85TgSZWH67l2jKVpDE5bjIaut7rjwwOiHzQ= software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU= software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= diff --git a/olm/olm.go b/olm/olm.go index 9803516..5a521f6 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -15,6 +15,7 @@ import ( "github.com/fosrl/olm/api" "github.com/fosrl/olm/network" "github.com/fosrl/olm/peermonitor" + "github.com/fosrl/olm/tunfilter" "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -81,6 +82,12 @@ var ( globalCtx context.Context stopRegister func() stopPing chan struct{} + + // Packet interceptor components + filteredDev *tunfilter.FilteredDevice + packetInjector *tunfilter.PacketInjector + interceptorManager *tunfilter.InterceptorManager + ipFilter *tunfilter.IPFilter ) func Init(ctx context.Context, config GlobalConfig) { @@ -424,6 +431,16 @@ func StartTunnel(config TunnelConfig) { } } + // Create packet injector for the TUN device + packetInjector = tunfilter.NewPacketInjector(tdev) + + // Create interceptor manager + interceptorManager = tunfilter.NewInterceptorManager(packetInjector) + + // Create an interceptor filter and wrap the TUN device + interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager) + filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter) + // fileUAPI, err := func() (*os.File, error) { // if config.FileDescriptorUAPI != 0 { // fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32) @@ -441,7 +458,8 @@ func StartTunnel(config TunnelConfig) { // } wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ") - dev = device.NewDevice(tdev, sharedBind, (*device.Logger)(wgLogger)) + // Use filtered device instead of raw TUN device + dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger)) // uapiListener, err = uapiListen(interfaceName, fileUAPI) // if err != nil { @@ -1048,6 +1066,26 @@ func Close() { dev = nil } + // Stop packet injector + if packetInjector != nil { + packetInjector.Stop() + packetInjector = nil + } + + // Stop interceptor manager + if interceptorManager != nil { + interceptorManager.Stop() + interceptorManager = nil + } + + // Clear packet filter + if filteredDev != nil { + filteredDev.SetFilter(nil) + filteredDev = nil + } + + ipFilter = nil + // Close TUN device if tdev != nil { tdev.Close() diff --git a/tunfilter/README.md b/tunfilter/README.md new file mode 100644 index 0000000..aa74312 --- /dev/null +++ b/tunfilter/README.md @@ -0,0 +1,215 @@ +# TUN Filter Interceptor System + +An extensible packet filtering and interception framework for the olm TUN device. + +## Architecture + +The system consists of several components that work together: + +``` +┌─────────────────┐ +│ WireGuard │ +└────────┬────────┘ + │ +┌────────▼────────┐ +│ FilteredDevice │ (Wraps TUN device) +└────────┬────────┘ + │ +┌────────▼──────────────┐ +│ InterceptorFilter │ +└────────┬──────────────┘ + │ +┌────────▼──────────────┐ +│ InterceptorManager │ +│ ┌─────────────────┐ │ +│ │ DNS Proxy │ │ +│ ├─────────────────┤ │ +│ │ Future... │ │ +│ └─────────────────┘ │ +└────────┬──────────────┘ + │ +┌────────▼────────┐ +│ TUN Device │ +└─────────────────┘ +``` + +## Components + +### FilteredDevice +- Wraps the TUN device +- Calls packet filters for every packet in both directions +- Located between WireGuard and the TUN device + +### PacketInterceptor Interface +Extensible interface for creating custom packet interceptors: +```go +type PacketInterceptor interface { + Name() string + ShouldIntercept(packet []byte, direction Direction) bool + HandlePacket(ctx context.Context, packet []byte, direction Direction) error + Start(ctx context.Context) error + Stop() error +} +``` + +### InterceptorManager +- Manages multiple interceptors +- Routes packets to the first matching interceptor +- Handles lifecycle (start/stop) for all interceptors + +### PacketInjector +- Allows interceptors to inject response packets +- Writes packets back into the TUN device as if they came from the tunnel + +### DNS Proxy Interceptor +Example implementation that: +- Intercepts DNS queries to `10.30.30.30` +- Forwards them to `8.8.8.8` +- Injects responses back as if they came from `10.30.30.30` + +## Usage + +The system is automatically initialized in `olm.go` when a tunnel is created: + +```go +// Create packet injector for the TUN device +packetInjector = tunfilter.NewPacketInjector(tdev) + +// Create interceptor manager +interceptorManager = tunfilter.NewInterceptorManager(packetInjector) + +// Add DNS proxy interceptor for 10.30.30.30 +dnsProxy := tunfilter.NewDNSProxyInterceptor( + tunfilter.DNSProxyConfig{ + Name: "dns-proxy", + InterceptIP: netip.MustParseAddr("10.30.30.30"), + UpstreamDNS: "8.8.8.8:53", + LocalIP: tunnelIP, + }, + packetInjector, +) + +interceptorManager.AddInterceptor(dnsProxy) + +// Create filter and wrap TUN device +interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager) +filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter) +``` + +## Adding New Interceptors + +To create a new interceptor: + +1. **Implement the PacketInterceptor interface:** + +```go +type MyInterceptor struct { + name string + injector *tunfilter.PacketInjector + // your fields... +} + +func (i *MyInterceptor) Name() string { + return i.name +} + +func (i *MyInterceptor) ShouldIntercept(packet []byte, direction tunfilter.Direction) bool { + // Quick check: parse packet and decide if you want to handle it + // This is called for EVERY packet, so make it fast! + info, ok := tunfilter.ParsePacket(packet) + if !ok { + return false + } + + // Example: intercept UDP packets to a specific IP and port + return info.IsUDP && info.DstIP == myTargetIP && info.DstPort == myPort +} + +func (i *MyInterceptor) HandlePacket(ctx context.Context, packet []byte, direction tunfilter.Direction) error { + // Process the packet + // You can: + // 1. Extract data from it + // 2. Make external requests + // 3. Inject response packets using i.injector.InjectInbound(responsePacket) + + return nil +} + +func (i *MyInterceptor) Start(ctx context.Context) error { + // Initialize resources (e.g., start listeners, connect to services) + return nil +} + +func (i *MyInterceptor) Stop() error { + // Clean up resources + return nil +} +``` + +2. **Register it with the manager:** + +```go +myInterceptor := NewMyInterceptor(...) +if err := interceptorManager.AddInterceptor(myInterceptor); err != nil { + logger.Error("Failed to add interceptor: %v", err) +} +``` + +## Packet Flow + +### Outbound (Host → Tunnel) +1. Packet written by application +2. TUN device receives it +3. FilteredDevice.Write intercepts it +4. InterceptorFilter checks all interceptors +5. If intercepted: Handler processes it, returns FilterActionIntercept +6. If passed: Packet continues to WireGuard for encryption + +### Inbound (Tunnel → Host) +1. WireGuard decrypts packet +2. FilteredDevice.Read intercepts it +3. InterceptorFilter checks all interceptors +4. If intercepted: Handler processes it, returns FilterActionIntercept +5. If passed: Packet written to TUN device for delivery to host + +## Example: DNS Proxy + +DNS queries to `10.30.30.30:53` are intercepted: + +``` +Application → 10.30.30.30:53 + ↓ + DNSProxyInterceptor + ↓ + Forward to 8.8.8.8:53 + ↓ + Get response + ↓ + Build response packet (src: 10.30.30.30) + ↓ + Inject into TUN device + ↓ + Application receives response +``` + +All other traffic flows normally through the WireGuard tunnel. + +## Future Ideas + +The interceptor system can be extended for: + +- **HTTP Proxy**: Intercept HTTP traffic and route through a proxy +- **Protocol Translation**: Convert one protocol to another +- **Traffic Shaping**: Add delays, simulate packet loss +- **Logging/Monitoring**: Record specific traffic patterns +- **Custom DNS Rules**: Different upstream servers based on domain +- **Local Service Integration**: Route certain IPs to local services +- **mDNS Support**: Handle multicast DNS queries locally + +## Performance Notes + +- `ShouldIntercept()` is called for every packet - keep it fast! +- Use simple checks (IP/port comparisons) +- Avoid allocations in the hot path +- Packet handling runs in a goroutine to avoid blocking +- The filtered device uses zero-copy techniques where possible diff --git a/tunfilter/filter.go b/tunfilter/filter.go new file mode 100644 index 0000000..bb1acfa --- /dev/null +++ b/tunfilter/filter.go @@ -0,0 +1,35 @@ +package tunfilter + +// FilterAction defines what to do with a packet +type FilterAction int + +const ( + // FilterActionPass allows the packet to continue normally + FilterActionPass FilterAction = iota + // FilterActionDrop silently drops the packet + FilterActionDrop + // FilterActionIntercept captures the packet for custom handling + FilterActionIntercept +) + +// PacketFilter interface for filtering and intercepting packets +type PacketFilter interface { + // FilterOutbound filters packets going FROM host TO tunnel (before encryption) + // Return FilterActionPass to allow, FilterActionDrop to drop, FilterActionIntercept to handle + FilterOutbound(packet []byte, size int) FilterAction + + // FilterInbound filters packets coming FROM tunnel TO host (after decryption) + // Return FilterActionPass to allow, FilterActionDrop to drop, FilterActionIntercept to handle + FilterInbound(packet []byte, size int) FilterAction +} + +// HandlerFunc is called when a packet is intercepted +type HandlerFunc func(packet []byte, direction Direction) error + +// Direction indicates packet flow direction +type Direction int + +const ( + DirectionOutbound Direction = iota // Host -> Tunnel + DirectionInbound // Tunnel -> Host +) diff --git a/tunfilter/filter_test.go b/tunfilter/filter_test.go new file mode 100644 index 0000000..830b05a --- /dev/null +++ b/tunfilter/filter_test.go @@ -0,0 +1,159 @@ +package tunfilter_test + +import ( + "encoding/binary" + "net/netip" + "testing" + + "github.com/fosrl/olm/tunfilter" +) + +// TestIPFilter validates the IP-based packet filtering +func TestIPFilter(t *testing.T) { + filter := tunfilter.NewIPFilter() + + // Create a test handler that just tracks calls + handler := func(packet []byte, direction tunfilter.Direction) error { + return nil + } + + // Add IP to intercept + targetIP := netip.MustParseAddr("10.30.30.30") + filter.AddInterceptIP(targetIP, handler) + + // Create a test packet destined for 10.30.30.30 + packet := buildTestPacket( + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("10.30.30.30"), + 12345, + 51821, + ) + + // Filter the packet (outbound direction) + action := filter.FilterOutbound(packet, len(packet)) + + // Should be intercepted + if action != tunfilter.FilterActionIntercept { + t.Errorf("Expected FilterActionIntercept, got %v", action) + } + + // Handler should eventually be called (async) + // In real tests you'd use sync primitives +} + +// TestPacketParsing validates packet information extraction +func TestPacketParsing(t *testing.T) { + srcIP := netip.MustParseAddr("192.168.1.100") + dstIP := netip.MustParseAddr("10.30.30.30") + srcPort := uint16(54321) + dstPort := uint16(51821) + + packet := buildTestPacket(srcIP, dstIP, srcPort, dstPort) + + info, ok := tunfilter.ParsePacket(packet) + if !ok { + t.Fatal("Failed to parse packet") + } + + if info.SrcIP != srcIP { + t.Errorf("Expected src IP %s, got %s", srcIP, info.SrcIP) + } + + if info.DstIP != dstIP { + t.Errorf("Expected dst IP %s, got %s", dstIP, info.DstIP) + } + + if info.SrcPort != srcPort { + t.Errorf("Expected src port %d, got %d", srcPort, info.SrcPort) + } + + if info.DstPort != dstPort { + t.Errorf("Expected dst port %d, got %d", dstPort, info.DstPort) + } + + if !info.IsUDP { + t.Error("Expected UDP packet") + } + + if info.Protocol != 17 { + t.Errorf("Expected protocol 17 (UDP), got %d", info.Protocol) + } +} + +// TestUDPResponsePacketConstruction validates packet building +func TestUDPResponsePacketConstruction(t *testing.T) { + // This would test the buildUDPResponse function + // For now, it's internal to NetstackHandler + // You could expose it or test via the full handler +} + +// Benchmark packet filtering performance +func BenchmarkIPFilterPassthrough(b *testing.B) { + filter := tunfilter.NewIPFilter() + packet := buildTestPacket( + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("192.168.1.2"), + 12345, + 80, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + filter.FilterOutbound(packet, len(packet)) + } +} + +func BenchmarkIPFilterWithIntercept(b *testing.B) { + filter := tunfilter.NewIPFilter() + + targetIP := netip.MustParseAddr("10.30.30.30") + filter.AddInterceptIP(targetIP, func(p []byte, d tunfilter.Direction) error { + return nil + }) + + packet := buildTestPacket( + netip.MustParseAddr("192.168.1.1"), + netip.MustParseAddr("10.30.30.30"), + 12345, + 51821, + ) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + filter.FilterOutbound(packet, len(packet)) + } +} + +// buildTestPacket creates a minimal UDP/IP packet for testing +func buildTestPacket(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) []byte { + payload := []byte("test payload") + totalLen := 20 + 8 + len(payload) // IP + UDP + payload + packet := make([]byte, totalLen) + + // IP Header + packet[0] = 0x45 // Version 4, IHL 5 + binary.BigEndian.PutUint16(packet[2:4], uint16(totalLen)) + packet[8] = 64 // TTL + packet[9] = 17 // UDP + + srcIPBytes := srcIP.As4() + copy(packet[12:16], srcIPBytes[:]) + + dstIPBytes := dstIP.As4() + copy(packet[16:20], dstIPBytes[:]) + + // IP Checksum (simplified - just set to 0 for testing) + packet[10] = 0 + packet[11] = 0 + + // UDP Header + binary.BigEndian.PutUint16(packet[20:22], srcPort) + binary.BigEndian.PutUint16(packet[22:24], dstPort) + binary.BigEndian.PutUint16(packet[24:26], uint16(8+len(payload))) + binary.BigEndian.PutUint16(packet[26:28], 0) // Checksum + + // Payload + copy(packet[28:], payload) + + return packet +} diff --git a/tunfilter/filtered_device.go b/tunfilter/filtered_device.go new file mode 100644 index 0000000..6197ec6 --- /dev/null +++ b/tunfilter/filtered_device.go @@ -0,0 +1,106 @@ +package tunfilter + +import ( + "sync" + + "golang.zx2c4.com/wireguard/tun" +) + +// FilteredDevice wraps a TUN device with packet filtering capabilities +// This sits between WireGuard and the TUN device, intercepting packets in both directions +type FilteredDevice struct { + tun.Device + filter PacketFilter + mutex sync.RWMutex +} + +// NewFilteredDevice creates a new filtered TUN device wrapper +func NewFilteredDevice(device tun.Device, filter PacketFilter) *FilteredDevice { + return &FilteredDevice{ + Device: device, + filter: filter, + } +} + +// Read intercepts packets from the TUN device (outbound from tunnel) +// These are decrypted packets coming out of WireGuard going to the host +func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { + n, err = d.Device.Read(bufs, sizes, offset) + if err != nil || n == 0 { + return n, err + } + + d.mutex.RLock() + filter := d.filter + d.mutex.RUnlock() + + if filter == nil { + return n, err + } + + // Filter packets in place to avoid allocations + // Process from the end to avoid index issues when removing + kept := 0 + for i := 0; i < n; i++ { + packet := bufs[i][offset : offset+sizes[i]] + + // FilterInbound: packet coming FROM tunnel TO host + if action := filter.FilterInbound(packet, sizes[i]); action == FilterActionPass { + // Keep this packet - move it to the "kept" position if needed + if kept != i { + bufs[kept] = bufs[i] + sizes[kept] = sizes[i] + } + kept++ + } + // FilterActionDrop or FilterActionIntercept: don't increment kept + } + + return kept, err +} + +// Write intercepts packets going to the TUN device (inbound to tunnel) +// These are packets from the host going into WireGuard for encryption +func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) { + d.mutex.RLock() + filter := d.filter + d.mutex.RUnlock() + + if filter == nil { + return d.Device.Write(bufs, offset) + } + + // Pre-allocate with capacity to avoid most allocations + filteredBufs := make([][]byte, 0, len(bufs)) + intercepted := 0 + + for _, buf := range bufs { + size := len(buf) - offset + packet := buf[offset:] + + // FilterOutbound: packet going FROM host TO tunnel + if action := filter.FilterOutbound(packet, size); action == FilterActionPass { + filteredBufs = append(filteredBufs, buf) + } else { + // Packet was dropped or intercepted + intercepted++ + } + } + + if len(filteredBufs) == 0 { + // All packets were intercepted/dropped + return len(bufs), nil + } + + n, err := d.Device.Write(filteredBufs, offset) + // Add back the intercepted count so WireGuard thinks all packets were processed + n += intercepted + return n, err +} + +// SetFilter updates the packet filter (thread-safe) +func (d *FilteredDevice) SetFilter(filter PacketFilter) { + d.mutex.Lock() + d.filter = filter + d.mutex.Unlock() +} diff --git a/tunfilter/injector.go b/tunfilter/injector.go new file mode 100644 index 0000000..55ca057 --- /dev/null +++ b/tunfilter/injector.go @@ -0,0 +1,69 @@ +package tunfilter + +import ( + "fmt" + "sync" + + "golang.zx2c4.com/wireguard/tun" +) + +// PacketInjector allows interceptors to inject packets back into the TUN device +// This is useful for sending response packets or injecting traffic +type PacketInjector struct { + device tun.Device + mutex sync.RWMutex +} + +// NewPacketInjector creates a new packet injector +func NewPacketInjector(device tun.Device) *PacketInjector { + return &PacketInjector{ + device: device, + } +} + +// InjectInbound injects a packet as if it came from the tunnel (to the host) +// This writes the packet to the TUN device so it appears as incoming traffic +func (p *PacketInjector) InjectInbound(packet []byte) error { + p.mutex.RLock() + device := p.device + p.mutex.RUnlock() + + if device == nil { + return fmt.Errorf("device not set") + } + + // TUN device expects packets in a specific format + // We need to write to the device with the proper offset + const offset = 4 // Standard TUN offset for packet info + + // Create buffer with offset + buf := make([]byte, offset+len(packet)) + copy(buf[offset:], packet) + + // Write packet + bufs := [][]byte{buf} + n, err := device.Write(bufs, offset) + if err != nil { + return fmt.Errorf("failed to inject packet: %w", err) + } + + if n != 1 { + return fmt.Errorf("expected to write 1 packet, wrote %d", n) + } + + return nil +} + +// Stop cleans up the injector +func (p *PacketInjector) Stop() { + p.mutex.Lock() + defer p.mutex.Unlock() + p.device = nil +} + +// SetDevice updates the underlying TUN device +func (p *PacketInjector) SetDevice(device tun.Device) { + p.mutex.Lock() + defer p.mutex.Unlock() + p.device = device +} diff --git a/tunfilter/interceptor.go b/tunfilter/interceptor.go new file mode 100644 index 0000000..6a03965 --- /dev/null +++ b/tunfilter/interceptor.go @@ -0,0 +1,140 @@ +package tunfilter + +import ( + "context" + "sync" +) + +// PacketInterceptor is an extensible interface for intercepting and handling packets +// before they go through the WireGuard tunnel +type PacketInterceptor interface { + // Name returns the interceptor's name for logging/debugging + Name() string + + // ShouldIntercept returns true if this interceptor wants to handle the packet + // This is called for every packet, so it should be fast (just check IP/port) + ShouldIntercept(packet []byte, direction Direction) bool + + // HandlePacket processes an intercepted packet + // The interceptor can: + // - Handle it completely and return nil (packet won't go through tunnel) + // - Return an error if something went wrong + // Context can be used for cancellation + HandlePacket(ctx context.Context, packet []byte, direction Direction) error + + // Start initializes the interceptor (e.g., start listening sockets) + Start(ctx context.Context) error + + // Stop cleanly shuts down the interceptor + Stop() error +} + +// InterceptorManager manages multiple packet interceptors +type InterceptorManager struct { + interceptors []PacketInterceptor + injector *PacketInjector + ctx context.Context + cancel context.CancelFunc + mutex sync.RWMutex +} + +// NewInterceptorManager creates a new interceptor manager +func NewInterceptorManager(injector *PacketInjector) *InterceptorManager { + ctx, cancel := context.WithCancel(context.Background()) + return &InterceptorManager{ + interceptors: make([]PacketInterceptor, 0), + injector: injector, + ctx: ctx, + cancel: cancel, + } +} + +// AddInterceptor adds a new interceptor to the manager +func (m *InterceptorManager) AddInterceptor(interceptor PacketInterceptor) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + m.interceptors = append(m.interceptors, interceptor) + + // Start the interceptor + if err := interceptor.Start(m.ctx); err != nil { + return err + } + + return nil +} + +// RemoveInterceptor removes an interceptor by name +func (m *InterceptorManager) RemoveInterceptor(name string) error { + m.mutex.Lock() + defer m.mutex.Unlock() + + for i, interceptor := range m.interceptors { + if interceptor.Name() == name { + // Stop the interceptor + if err := interceptor.Stop(); err != nil { + return err + } + + // Remove from slice + m.interceptors = append(m.interceptors[:i], m.interceptors[i+1:]...) + return nil + } + } + + return nil +} + +// HandlePacket is called by the filter for each packet +// It checks all interceptors in order and lets the first matching one handle it +func (m *InterceptorManager) HandlePacket(packet []byte, direction Direction) FilterAction { + m.mutex.RLock() + interceptors := m.interceptors + m.mutex.RUnlock() + + // Try each interceptor in order + for _, interceptor := range interceptors { + if interceptor.ShouldIntercept(packet, direction) { + // Make a copy to avoid data races + packetCopy := make([]byte, len(packet)) + copy(packetCopy, packet) + + // Handle in background to avoid blocking packet processing + go func(ic PacketInterceptor, pkt []byte) { + if err := ic.HandlePacket(m.ctx, pkt, direction); err != nil { + // Log error but don't fail + // TODO: Add proper logging + } + }(interceptor, packetCopy) + + // Packet was intercepted + return FilterActionIntercept + } + } + + // No interceptor wanted this packet + return FilterActionPass +} + +// Stop stops all interceptors +func (m *InterceptorManager) Stop() error { + m.cancel() + + m.mutex.Lock() + defer m.mutex.Unlock() + + var lastErr error + for _, interceptor := range m.interceptors { + if err := interceptor.Stop(); err != nil { + lastErr = err + } + } + + m.interceptors = nil + return lastErr +} + +// GetInjector returns the packet injector for interceptors to use +func (m *InterceptorManager) GetInjector() *PacketInjector { + return m.injector +} diff --git a/tunfilter/interceptor_filter.go b/tunfilter/interceptor_filter.go new file mode 100644 index 0000000..a2de341 --- /dev/null +++ b/tunfilter/interceptor_filter.go @@ -0,0 +1,30 @@ +package tunfilter + +// InterceptorFilter is a PacketFilter that uses an InterceptorManager +// This allows the filtered device to work with the new interceptor system +type InterceptorFilter struct { + manager *InterceptorManager +} + +// NewInterceptorFilter creates a new filter that uses an interceptor manager +func NewInterceptorFilter(manager *InterceptorManager) *InterceptorFilter { + return &InterceptorFilter{ + manager: manager, + } +} + +// FilterOutbound checks all interceptors for outbound packets +func (f *InterceptorFilter) FilterOutbound(packet []byte, size int) FilterAction { + if f.manager == nil { + return FilterActionPass + } + return f.manager.HandlePacket(packet, DirectionOutbound) +} + +// FilterInbound checks all interceptors for inbound packets +func (f *InterceptorFilter) FilterInbound(packet []byte, size int) FilterAction { + if f.manager == nil { + return FilterActionPass + } + return f.manager.HandlePacket(packet, DirectionInbound) +} diff --git a/tunfilter/ipfilter.go b/tunfilter/ipfilter.go new file mode 100644 index 0000000..95dbecc --- /dev/null +++ b/tunfilter/ipfilter.go @@ -0,0 +1,194 @@ +package tunfilter + +import ( + "encoding/binary" + "net/netip" + "sync" +) + +// IPFilter provides fast IP-based packet filtering and interception +type IPFilter struct { + // Map of IP addresses to intercept (for O(1) lookup) + interceptIPs map[netip.Addr]HandlerFunc + mutex sync.RWMutex +} + +// NewIPFilter creates a new IP-based packet filter +func NewIPFilter() *IPFilter { + return &IPFilter{ + interceptIPs: make(map[netip.Addr]HandlerFunc), + } +} + +// AddInterceptIP adds an IP address to intercept +// All packets to/from this IP will be passed to the handler function +func (f *IPFilter) AddInterceptIP(ip netip.Addr, handler HandlerFunc) { + f.mutex.Lock() + defer f.mutex.Unlock() + f.interceptIPs[ip] = handler +} + +// RemoveInterceptIP removes an IP from interception +func (f *IPFilter) RemoveInterceptIP(ip netip.Addr) { + f.mutex.Lock() + defer f.mutex.Unlock() + delete(f.interceptIPs, ip) +} + +// FilterOutbound filters packets going from host to tunnel +func (f *IPFilter) FilterOutbound(packet []byte, size int) FilterAction { + // Fast path: no interceptors configured + f.mutex.RLock() + hasInterceptors := len(f.interceptIPs) > 0 + f.mutex.RUnlock() + + if !hasInterceptors { + return FilterActionPass + } + + // Parse IP header (minimum 20 bytes) + if size < 20 { + return FilterActionPass + } + + // Check IP version (IPv4 only for now) + version := packet[0] >> 4 + if version != 4 { + return FilterActionPass + } + + // Extract destination IP (bytes 16-20 in IPv4 header) + dstIP, ok := netip.AddrFromSlice(packet[16:20]) + if !ok { + return FilterActionPass + } + + // Check if this IP should be intercepted + f.mutex.RLock() + handler, shouldIntercept := f.interceptIPs[dstIP] + f.mutex.RUnlock() + + if shouldIntercept && handler != nil { + // Make a copy of the packet for the handler (to avoid data races) + packetCopy := make([]byte, size) + copy(packetCopy, packet[:size]) + + // Call handler in background to avoid blocking packet processing + go handler(packetCopy, DirectionOutbound) + + // Intercept the packet (don't send it through the tunnel) + return FilterActionIntercept + } + + return FilterActionPass +} + +// FilterInbound filters packets coming from tunnel to host +func (f *IPFilter) FilterInbound(packet []byte, size int) FilterAction { + // Fast path: no interceptors configured + f.mutex.RLock() + hasInterceptors := len(f.interceptIPs) > 0 + f.mutex.RUnlock() + + if !hasInterceptors { + return FilterActionPass + } + + // Parse IP header (minimum 20 bytes) + if size < 20 { + return FilterActionPass + } + + // Check IP version (IPv4 only for now) + version := packet[0] >> 4 + if version != 4 { + return FilterActionPass + } + + // Extract source IP (bytes 12-16 in IPv4 header) + srcIP, ok := netip.AddrFromSlice(packet[12:16]) + if !ok { + return FilterActionPass + } + + // Check if this IP should be intercepted + f.mutex.RLock() + handler, shouldIntercept := f.interceptIPs[srcIP] + f.mutex.RUnlock() + + if shouldIntercept && handler != nil { + // Make a copy of the packet for the handler + packetCopy := make([]byte, size) + copy(packetCopy, packet[:size]) + + // Call handler in background + go handler(packetCopy, DirectionInbound) + + // Intercept the packet (don't deliver to host) + return FilterActionIntercept + } + + return FilterActionPass +} + +// ParsePacketInfo extracts useful information from a packet for debugging/logging +type PacketInfo struct { + Version uint8 + Protocol uint8 + SrcIP netip.Addr + DstIP netip.Addr + SrcPort uint16 + DstPort uint16 + IsUDP bool + IsTCP bool + PayloadLen int +} + +// ParsePacket extracts packet information (useful for handlers) +func ParsePacket(packet []byte) (*PacketInfo, bool) { + if len(packet) < 20 { + return nil, false + } + + info := &PacketInfo{} + + // IP version + info.Version = packet[0] >> 4 + if info.Version != 4 { + return nil, false + } + + // Protocol + info.Protocol = packet[9] + info.IsUDP = info.Protocol == 17 + info.IsTCP = info.Protocol == 6 + + // Source and destination IPs + if srcIP, ok := netip.AddrFromSlice(packet[12:16]); ok { + info.SrcIP = srcIP + } + if dstIP, ok := netip.AddrFromSlice(packet[16:20]); ok { + info.DstIP = dstIP + } + + // Get IP header length + ihl := int(packet[0]&0x0f) * 4 + if len(packet) < ihl { + return info, true + } + + // Extract ports for TCP/UDP + if (info.IsUDP || info.IsTCP) && len(packet) >= ihl+4 { + info.SrcPort = binary.BigEndian.Uint16(packet[ihl : ihl+2]) + info.DstPort = binary.BigEndian.Uint16(packet[ihl+2 : ihl+4]) + } + + // Payload length + totalLen := binary.BigEndian.Uint16(packet[2:4]) + info.PayloadLen = int(totalLen) - ihl + if info.IsUDP || info.IsTCP { + info.PayloadLen -= 8 // UDP header size + } + + return info, true +}