First try

This commit is contained in:
Owen
2025-11-21 11:59:44 -05:00
parent a3e34f3cc0
commit f882cd983b
11 changed files with 996 additions and 6 deletions

8
go.mod
View File

@@ -7,19 +7,21 @@ require (
github.com/fosrl/newt v0.0.0
github.com/gorilla/websocket v1.5.3
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.44.0
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6
golang.org/x/sys v0.38.0
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e
software.sslmate.com/src/go-pkcs12 v0.6.0
)
require (
github.com/google/btree v1.1.3 // indirect
github.com/vishvananda/netns v0.0.5 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/time v0.12.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 // indirect
)
replace github.com/fosrl/newt => ../newt

6
go.sum
View File

@@ -14,6 +14,8 @@ golang.org/x/crypto v0.44.0 h1:A97SsFvM3AIwEEmTBiaxPPTYpDC47w720rdiiUvgoAU=
golang.org/x/crypto v0.44.0/go.mod h1:013i+Nw79BMiQiMsOPcVCB5ZIJbYkerPrGnOa00tvmc=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6 h1:zfMcR1Cs4KNuomFFgGefv5N0czO2XZpUbxGUy8i8ug0=
golang.org/x/exp v0.0.0-20251113190631-e25ba8c21ef6/go.mod h1:46edojNIoXTNOhySWIWdix628clX9ODXwPsQuG6hsK0=
golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -28,7 +30,7 @@ golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb h1:whnFRlWMcXI9d+Z
golang.zx2c4.com/wireguard v0.0.0-20250521234502-f333402bd9cb/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56 h1:H+qymc2ndLKNFR5TcaPmsHGiJnhJMqeofBYSRq4oG3c=
gvisor.dev/gvisor v0.0.0-20250718192347-d7830d968c56/go.mod h1:i8iCZyAdwRnLZYaIi2NUL1gfNtAveqxkKAe0JfAv9Bs=
gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e h1:upyNwibTehzZl2FY2LEQ6bTRKOrU0IMiBLiIKT+dKF0=
gvisor.dev/gvisor v0.0.0-20251121015435-2879878b845e/go.mod h1:W1ZgZ/Dh85TgSZWH67l2jKVpDE5bjIaut7rjwwOiHzQ=
software.sslmate.com/src/go-pkcs12 v0.6.0 h1:f3sQittAeF+pao32Vb+mkli+ZyT+VwKaD014qFGq6oU=
software.sslmate.com/src/go-pkcs12 v0.6.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI=

View File

@@ -15,6 +15,7 @@ import (
"github.com/fosrl/olm/api"
"github.com/fosrl/olm/network"
"github.com/fosrl/olm/peermonitor"
"github.com/fosrl/olm/tunfilter"
"github.com/fosrl/olm/websocket"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -81,6 +82,12 @@ var (
globalCtx context.Context
stopRegister func()
stopPing chan struct{}
// Packet interceptor components
filteredDev *tunfilter.FilteredDevice
packetInjector *tunfilter.PacketInjector
interceptorManager *tunfilter.InterceptorManager
ipFilter *tunfilter.IPFilter
)
func Init(ctx context.Context, config GlobalConfig) {
@@ -424,6 +431,16 @@ func StartTunnel(config TunnelConfig) {
}
}
// Create packet injector for the TUN device
packetInjector = tunfilter.NewPacketInjector(tdev)
// Create interceptor manager
interceptorManager = tunfilter.NewInterceptorManager(packetInjector)
// Create an interceptor filter and wrap the TUN device
interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager)
filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter)
// fileUAPI, err := func() (*os.File, error) {
// if config.FileDescriptorUAPI != 0 {
// fd, err := strconv.ParseUint(fmt.Sprintf("%d", config.FileDescriptorUAPI), 10, 32)
@@ -441,7 +458,8 @@ func StartTunnel(config TunnelConfig) {
// }
wgLogger := logger.GetLogger().GetWireGuardLogger("wireguard: ")
dev = device.NewDevice(tdev, sharedBind, (*device.Logger)(wgLogger))
// Use filtered device instead of raw TUN device
dev = device.NewDevice(filteredDev, sharedBind, (*device.Logger)(wgLogger))
// uapiListener, err = uapiListen(interfaceName, fileUAPI)
// if err != nil {
@@ -1048,6 +1066,26 @@ func Close() {
dev = nil
}
// Stop packet injector
if packetInjector != nil {
packetInjector.Stop()
packetInjector = nil
}
// Stop interceptor manager
if interceptorManager != nil {
interceptorManager.Stop()
interceptorManager = nil
}
// Clear packet filter
if filteredDev != nil {
filteredDev.SetFilter(nil)
filteredDev = nil
}
ipFilter = nil
// Close TUN device
if tdev != nil {
tdev.Close()

215
tunfilter/README.md Normal file
View File

@@ -0,0 +1,215 @@
# TUN Filter Interceptor System
An extensible packet filtering and interception framework for the olm TUN device.
## Architecture
The system consists of several components that work together:
```
┌─────────────────┐
│ WireGuard │
└────────┬────────┘
┌────────▼────────┐
│ FilteredDevice │ (Wraps TUN device)
└────────┬────────┘
┌────────▼──────────────┐
│ InterceptorFilter │
└────────┬──────────────┘
┌────────▼──────────────┐
│ InterceptorManager │
│ ┌─────────────────┐ │
│ │ DNS Proxy │ │
│ ├─────────────────┤ │
│ │ Future... │ │
│ └─────────────────┘ │
└────────┬──────────────┘
┌────────▼────────┐
│ TUN Device │
└─────────────────┘
```
## Components
### FilteredDevice
- Wraps the TUN device
- Calls packet filters for every packet in both directions
- Located between WireGuard and the TUN device
### PacketInterceptor Interface
Extensible interface for creating custom packet interceptors:
```go
type PacketInterceptor interface {
Name() string
ShouldIntercept(packet []byte, direction Direction) bool
HandlePacket(ctx context.Context, packet []byte, direction Direction) error
Start(ctx context.Context) error
Stop() error
}
```
### InterceptorManager
- Manages multiple interceptors
- Routes packets to the first matching interceptor
- Handles lifecycle (start/stop) for all interceptors
### PacketInjector
- Allows interceptors to inject response packets
- Writes packets back into the TUN device as if they came from the tunnel
### DNS Proxy Interceptor
Example implementation that:
- Intercepts DNS queries to `10.30.30.30`
- Forwards them to `8.8.8.8`
- Injects responses back as if they came from `10.30.30.30`
## Usage
The system is automatically initialized in `olm.go` when a tunnel is created:
```go
// Create packet injector for the TUN device
packetInjector = tunfilter.NewPacketInjector(tdev)
// Create interceptor manager
interceptorManager = tunfilter.NewInterceptorManager(packetInjector)
// Add DNS proxy interceptor for 10.30.30.30
dnsProxy := tunfilter.NewDNSProxyInterceptor(
tunfilter.DNSProxyConfig{
Name: "dns-proxy",
InterceptIP: netip.MustParseAddr("10.30.30.30"),
UpstreamDNS: "8.8.8.8:53",
LocalIP: tunnelIP,
},
packetInjector,
)
interceptorManager.AddInterceptor(dnsProxy)
// Create filter and wrap TUN device
interceptorFilter := tunfilter.NewInterceptorFilter(interceptorManager)
filteredDev = tunfilter.NewFilteredDevice(tdev, interceptorFilter)
```
## Adding New Interceptors
To create a new interceptor:
1. **Implement the PacketInterceptor interface:**
```go
type MyInterceptor struct {
name string
injector *tunfilter.PacketInjector
// your fields...
}
func (i *MyInterceptor) Name() string {
return i.name
}
func (i *MyInterceptor) ShouldIntercept(packet []byte, direction tunfilter.Direction) bool {
// Quick check: parse packet and decide if you want to handle it
// This is called for EVERY packet, so make it fast!
info, ok := tunfilter.ParsePacket(packet)
if !ok {
return false
}
// Example: intercept UDP packets to a specific IP and port
return info.IsUDP && info.DstIP == myTargetIP && info.DstPort == myPort
}
func (i *MyInterceptor) HandlePacket(ctx context.Context, packet []byte, direction tunfilter.Direction) error {
// Process the packet
// You can:
// 1. Extract data from it
// 2. Make external requests
// 3. Inject response packets using i.injector.InjectInbound(responsePacket)
return nil
}
func (i *MyInterceptor) Start(ctx context.Context) error {
// Initialize resources (e.g., start listeners, connect to services)
return nil
}
func (i *MyInterceptor) Stop() error {
// Clean up resources
return nil
}
```
2. **Register it with the manager:**
```go
myInterceptor := NewMyInterceptor(...)
if err := interceptorManager.AddInterceptor(myInterceptor); err != nil {
logger.Error("Failed to add interceptor: %v", err)
}
```
## Packet Flow
### Outbound (Host → Tunnel)
1. Packet written by application
2. TUN device receives it
3. FilteredDevice.Write intercepts it
4. InterceptorFilter checks all interceptors
5. If intercepted: Handler processes it, returns FilterActionIntercept
6. If passed: Packet continues to WireGuard for encryption
### Inbound (Tunnel → Host)
1. WireGuard decrypts packet
2. FilteredDevice.Read intercepts it
3. InterceptorFilter checks all interceptors
4. If intercepted: Handler processes it, returns FilterActionIntercept
5. If passed: Packet written to TUN device for delivery to host
## Example: DNS Proxy
DNS queries to `10.30.30.30:53` are intercepted:
```
Application → 10.30.30.30:53
DNSProxyInterceptor
Forward to 8.8.8.8:53
Get response
Build response packet (src: 10.30.30.30)
Inject into TUN device
Application receives response
```
All other traffic flows normally through the WireGuard tunnel.
## Future Ideas
The interceptor system can be extended for:
- **HTTP Proxy**: Intercept HTTP traffic and route through a proxy
- **Protocol Translation**: Convert one protocol to another
- **Traffic Shaping**: Add delays, simulate packet loss
- **Logging/Monitoring**: Record specific traffic patterns
- **Custom DNS Rules**: Different upstream servers based on domain
- **Local Service Integration**: Route certain IPs to local services
- **mDNS Support**: Handle multicast DNS queries locally
## Performance Notes
- `ShouldIntercept()` is called for every packet - keep it fast!
- Use simple checks (IP/port comparisons)
- Avoid allocations in the hot path
- Packet handling runs in a goroutine to avoid blocking
- The filtered device uses zero-copy techniques where possible

35
tunfilter/filter.go Normal file
View File

@@ -0,0 +1,35 @@
package tunfilter
// FilterAction defines what to do with a packet
type FilterAction int
const (
// FilterActionPass allows the packet to continue normally
FilterActionPass FilterAction = iota
// FilterActionDrop silently drops the packet
FilterActionDrop
// FilterActionIntercept captures the packet for custom handling
FilterActionIntercept
)
// PacketFilter interface for filtering and intercepting packets
type PacketFilter interface {
// FilterOutbound filters packets going FROM host TO tunnel (before encryption)
// Return FilterActionPass to allow, FilterActionDrop to drop, FilterActionIntercept to handle
FilterOutbound(packet []byte, size int) FilterAction
// FilterInbound filters packets coming FROM tunnel TO host (after decryption)
// Return FilterActionPass to allow, FilterActionDrop to drop, FilterActionIntercept to handle
FilterInbound(packet []byte, size int) FilterAction
}
// HandlerFunc is called when a packet is intercepted
type HandlerFunc func(packet []byte, direction Direction) error
// Direction indicates packet flow direction
type Direction int
const (
DirectionOutbound Direction = iota // Host -> Tunnel
DirectionInbound // Tunnel -> Host
)

159
tunfilter/filter_test.go Normal file
View File

@@ -0,0 +1,159 @@
package tunfilter_test
import (
"encoding/binary"
"net/netip"
"testing"
"github.com/fosrl/olm/tunfilter"
)
// TestIPFilter validates the IP-based packet filtering
func TestIPFilter(t *testing.T) {
filter := tunfilter.NewIPFilter()
// Create a test handler that just tracks calls
handler := func(packet []byte, direction tunfilter.Direction) error {
return nil
}
// Add IP to intercept
targetIP := netip.MustParseAddr("10.30.30.30")
filter.AddInterceptIP(targetIP, handler)
// Create a test packet destined for 10.30.30.30
packet := buildTestPacket(
netip.MustParseAddr("192.168.1.1"),
netip.MustParseAddr("10.30.30.30"),
12345,
51821,
)
// Filter the packet (outbound direction)
action := filter.FilterOutbound(packet, len(packet))
// Should be intercepted
if action != tunfilter.FilterActionIntercept {
t.Errorf("Expected FilterActionIntercept, got %v", action)
}
// Handler should eventually be called (async)
// In real tests you'd use sync primitives
}
// TestPacketParsing validates packet information extraction
func TestPacketParsing(t *testing.T) {
srcIP := netip.MustParseAddr("192.168.1.100")
dstIP := netip.MustParseAddr("10.30.30.30")
srcPort := uint16(54321)
dstPort := uint16(51821)
packet := buildTestPacket(srcIP, dstIP, srcPort, dstPort)
info, ok := tunfilter.ParsePacket(packet)
if !ok {
t.Fatal("Failed to parse packet")
}
if info.SrcIP != srcIP {
t.Errorf("Expected src IP %s, got %s", srcIP, info.SrcIP)
}
if info.DstIP != dstIP {
t.Errorf("Expected dst IP %s, got %s", dstIP, info.DstIP)
}
if info.SrcPort != srcPort {
t.Errorf("Expected src port %d, got %d", srcPort, info.SrcPort)
}
if info.DstPort != dstPort {
t.Errorf("Expected dst port %d, got %d", dstPort, info.DstPort)
}
if !info.IsUDP {
t.Error("Expected UDP packet")
}
if info.Protocol != 17 {
t.Errorf("Expected protocol 17 (UDP), got %d", info.Protocol)
}
}
// TestUDPResponsePacketConstruction validates packet building
func TestUDPResponsePacketConstruction(t *testing.T) {
// This would test the buildUDPResponse function
// For now, it's internal to NetstackHandler
// You could expose it or test via the full handler
}
// Benchmark packet filtering performance
func BenchmarkIPFilterPassthrough(b *testing.B) {
filter := tunfilter.NewIPFilter()
packet := buildTestPacket(
netip.MustParseAddr("192.168.1.1"),
netip.MustParseAddr("192.168.1.2"),
12345,
80,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.FilterOutbound(packet, len(packet))
}
}
func BenchmarkIPFilterWithIntercept(b *testing.B) {
filter := tunfilter.NewIPFilter()
targetIP := netip.MustParseAddr("10.30.30.30")
filter.AddInterceptIP(targetIP, func(p []byte, d tunfilter.Direction) error {
return nil
})
packet := buildTestPacket(
netip.MustParseAddr("192.168.1.1"),
netip.MustParseAddr("10.30.30.30"),
12345,
51821,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
filter.FilterOutbound(packet, len(packet))
}
}
// buildTestPacket creates a minimal UDP/IP packet for testing
func buildTestPacket(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) []byte {
payload := []byte("test payload")
totalLen := 20 + 8 + len(payload) // IP + UDP + payload
packet := make([]byte, totalLen)
// IP Header
packet[0] = 0x45 // Version 4, IHL 5
binary.BigEndian.PutUint16(packet[2:4], uint16(totalLen))
packet[8] = 64 // TTL
packet[9] = 17 // UDP
srcIPBytes := srcIP.As4()
copy(packet[12:16], srcIPBytes[:])
dstIPBytes := dstIP.As4()
copy(packet[16:20], dstIPBytes[:])
// IP Checksum (simplified - just set to 0 for testing)
packet[10] = 0
packet[11] = 0
// UDP Header
binary.BigEndian.PutUint16(packet[20:22], srcPort)
binary.BigEndian.PutUint16(packet[22:24], dstPort)
binary.BigEndian.PutUint16(packet[24:26], uint16(8+len(payload)))
binary.BigEndian.PutUint16(packet[26:28], 0) // Checksum
// Payload
copy(packet[28:], payload)
return packet
}

View File

@@ -0,0 +1,106 @@
package tunfilter
import (
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// FilteredDevice wraps a TUN device with packet filtering capabilities
// This sits between WireGuard and the TUN device, intercepting packets in both directions
type FilteredDevice struct {
tun.Device
filter PacketFilter
mutex sync.RWMutex
}
// NewFilteredDevice creates a new filtered TUN device wrapper
func NewFilteredDevice(device tun.Device, filter PacketFilter) *FilteredDevice {
return &FilteredDevice{
Device: device,
filter: filter,
}
}
// Read intercepts packets from the TUN device (outbound from tunnel)
// These are decrypted packets coming out of WireGuard going to the host
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 {
return n, err
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return n, err
}
// Filter packets in place to avoid allocations
// Process from the end to avoid index issues when removing
kept := 0
for i := 0; i < n; i++ {
packet := bufs[i][offset : offset+sizes[i]]
// FilterInbound: packet coming FROM tunnel TO host
if action := filter.FilterInbound(packet, sizes[i]); action == FilterActionPass {
// Keep this packet - move it to the "kept" position if needed
if kept != i {
bufs[kept] = bufs[i]
sizes[kept] = sizes[i]
}
kept++
}
// FilterActionDrop or FilterActionIntercept: don't increment kept
}
return kept, err
}
// Write intercepts packets going to the TUN device (inbound to tunnel)
// These are packets from the host going into WireGuard for encryption
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter == nil {
return d.Device.Write(bufs, offset)
}
// Pre-allocate with capacity to avoid most allocations
filteredBufs := make([][]byte, 0, len(bufs))
intercepted := 0
for _, buf := range bufs {
size := len(buf) - offset
packet := buf[offset:]
// FilterOutbound: packet going FROM host TO tunnel
if action := filter.FilterOutbound(packet, size); action == FilterActionPass {
filteredBufs = append(filteredBufs, buf)
} else {
// Packet was dropped or intercepted
intercepted++
}
}
if len(filteredBufs) == 0 {
// All packets were intercepted/dropped
return len(bufs), nil
}
n, err := d.Device.Write(filteredBufs, offset)
// Add back the intercepted count so WireGuard thinks all packets were processed
n += intercepted
return n, err
}
// SetFilter updates the packet filter (thread-safe)
func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.mutex.Lock()
d.filter = filter
d.mutex.Unlock()
}

69
tunfilter/injector.go Normal file
View File

@@ -0,0 +1,69 @@
package tunfilter
import (
"fmt"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketInjector allows interceptors to inject packets back into the TUN device
// This is useful for sending response packets or injecting traffic
type PacketInjector struct {
device tun.Device
mutex sync.RWMutex
}
// NewPacketInjector creates a new packet injector
func NewPacketInjector(device tun.Device) *PacketInjector {
return &PacketInjector{
device: device,
}
}
// InjectInbound injects a packet as if it came from the tunnel (to the host)
// This writes the packet to the TUN device so it appears as incoming traffic
func (p *PacketInjector) InjectInbound(packet []byte) error {
p.mutex.RLock()
device := p.device
p.mutex.RUnlock()
if device == nil {
return fmt.Errorf("device not set")
}
// TUN device expects packets in a specific format
// We need to write to the device with the proper offset
const offset = 4 // Standard TUN offset for packet info
// Create buffer with offset
buf := make([]byte, offset+len(packet))
copy(buf[offset:], packet)
// Write packet
bufs := [][]byte{buf}
n, err := device.Write(bufs, offset)
if err != nil {
return fmt.Errorf("failed to inject packet: %w", err)
}
if n != 1 {
return fmt.Errorf("expected to write 1 packet, wrote %d", n)
}
return nil
}
// Stop cleans up the injector
func (p *PacketInjector) Stop() {
p.mutex.Lock()
defer p.mutex.Unlock()
p.device = nil
}
// SetDevice updates the underlying TUN device
func (p *PacketInjector) SetDevice(device tun.Device) {
p.mutex.Lock()
defer p.mutex.Unlock()
p.device = device
}

140
tunfilter/interceptor.go Normal file
View File

@@ -0,0 +1,140 @@
package tunfilter
import (
"context"
"sync"
)
// PacketInterceptor is an extensible interface for intercepting and handling packets
// before they go through the WireGuard tunnel
type PacketInterceptor interface {
// Name returns the interceptor's name for logging/debugging
Name() string
// ShouldIntercept returns true if this interceptor wants to handle the packet
// This is called for every packet, so it should be fast (just check IP/port)
ShouldIntercept(packet []byte, direction Direction) bool
// HandlePacket processes an intercepted packet
// The interceptor can:
// - Handle it completely and return nil (packet won't go through tunnel)
// - Return an error if something went wrong
// Context can be used for cancellation
HandlePacket(ctx context.Context, packet []byte, direction Direction) error
// Start initializes the interceptor (e.g., start listening sockets)
Start(ctx context.Context) error
// Stop cleanly shuts down the interceptor
Stop() error
}
// InterceptorManager manages multiple packet interceptors
type InterceptorManager struct {
interceptors []PacketInterceptor
injector *PacketInjector
ctx context.Context
cancel context.CancelFunc
mutex sync.RWMutex
}
// NewInterceptorManager creates a new interceptor manager
func NewInterceptorManager(injector *PacketInjector) *InterceptorManager {
ctx, cancel := context.WithCancel(context.Background())
return &InterceptorManager{
interceptors: make([]PacketInterceptor, 0),
injector: injector,
ctx: ctx,
cancel: cancel,
}
}
// AddInterceptor adds a new interceptor to the manager
func (m *InterceptorManager) AddInterceptor(interceptor PacketInterceptor) error {
m.mutex.Lock()
defer m.mutex.Unlock()
m.interceptors = append(m.interceptors, interceptor)
// Start the interceptor
if err := interceptor.Start(m.ctx); err != nil {
return err
}
return nil
}
// RemoveInterceptor removes an interceptor by name
func (m *InterceptorManager) RemoveInterceptor(name string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
for i, interceptor := range m.interceptors {
if interceptor.Name() == name {
// Stop the interceptor
if err := interceptor.Stop(); err != nil {
return err
}
// Remove from slice
m.interceptors = append(m.interceptors[:i], m.interceptors[i+1:]...)
return nil
}
}
return nil
}
// HandlePacket is called by the filter for each packet
// It checks all interceptors in order and lets the first matching one handle it
func (m *InterceptorManager) HandlePacket(packet []byte, direction Direction) FilterAction {
m.mutex.RLock()
interceptors := m.interceptors
m.mutex.RUnlock()
// Try each interceptor in order
for _, interceptor := range interceptors {
if interceptor.ShouldIntercept(packet, direction) {
// Make a copy to avoid data races
packetCopy := make([]byte, len(packet))
copy(packetCopy, packet)
// Handle in background to avoid blocking packet processing
go func(ic PacketInterceptor, pkt []byte) {
if err := ic.HandlePacket(m.ctx, pkt, direction); err != nil {
// Log error but don't fail
// TODO: Add proper logging
}
}(interceptor, packetCopy)
// Packet was intercepted
return FilterActionIntercept
}
}
// No interceptor wanted this packet
return FilterActionPass
}
// Stop stops all interceptors
func (m *InterceptorManager) Stop() error {
m.cancel()
m.mutex.Lock()
defer m.mutex.Unlock()
var lastErr error
for _, interceptor := range m.interceptors {
if err := interceptor.Stop(); err != nil {
lastErr = err
}
}
m.interceptors = nil
return lastErr
}
// GetInjector returns the packet injector for interceptors to use
func (m *InterceptorManager) GetInjector() *PacketInjector {
return m.injector
}

View File

@@ -0,0 +1,30 @@
package tunfilter
// InterceptorFilter is a PacketFilter that uses an InterceptorManager
// This allows the filtered device to work with the new interceptor system
type InterceptorFilter struct {
manager *InterceptorManager
}
// NewInterceptorFilter creates a new filter that uses an interceptor manager
func NewInterceptorFilter(manager *InterceptorManager) *InterceptorFilter {
return &InterceptorFilter{
manager: manager,
}
}
// FilterOutbound checks all interceptors for outbound packets
func (f *InterceptorFilter) FilterOutbound(packet []byte, size int) FilterAction {
if f.manager == nil {
return FilterActionPass
}
return f.manager.HandlePacket(packet, DirectionOutbound)
}
// FilterInbound checks all interceptors for inbound packets
func (f *InterceptorFilter) FilterInbound(packet []byte, size int) FilterAction {
if f.manager == nil {
return FilterActionPass
}
return f.manager.HandlePacket(packet, DirectionInbound)
}

194
tunfilter/ipfilter.go Normal file
View File

@@ -0,0 +1,194 @@
package tunfilter
import (
"encoding/binary"
"net/netip"
"sync"
)
// IPFilter provides fast IP-based packet filtering and interception
type IPFilter struct {
// Map of IP addresses to intercept (for O(1) lookup)
interceptIPs map[netip.Addr]HandlerFunc
mutex sync.RWMutex
}
// NewIPFilter creates a new IP-based packet filter
func NewIPFilter() *IPFilter {
return &IPFilter{
interceptIPs: make(map[netip.Addr]HandlerFunc),
}
}
// AddInterceptIP adds an IP address to intercept
// All packets to/from this IP will be passed to the handler function
func (f *IPFilter) AddInterceptIP(ip netip.Addr, handler HandlerFunc) {
f.mutex.Lock()
defer f.mutex.Unlock()
f.interceptIPs[ip] = handler
}
// RemoveInterceptIP removes an IP from interception
func (f *IPFilter) RemoveInterceptIP(ip netip.Addr) {
f.mutex.Lock()
defer f.mutex.Unlock()
delete(f.interceptIPs, ip)
}
// FilterOutbound filters packets going from host to tunnel
func (f *IPFilter) FilterOutbound(packet []byte, size int) FilterAction {
// Fast path: no interceptors configured
f.mutex.RLock()
hasInterceptors := len(f.interceptIPs) > 0
f.mutex.RUnlock()
if !hasInterceptors {
return FilterActionPass
}
// Parse IP header (minimum 20 bytes)
if size < 20 {
return FilterActionPass
}
// Check IP version (IPv4 only for now)
version := packet[0] >> 4
if version != 4 {
return FilterActionPass
}
// Extract destination IP (bytes 16-20 in IPv4 header)
dstIP, ok := netip.AddrFromSlice(packet[16:20])
if !ok {
return FilterActionPass
}
// Check if this IP should be intercepted
f.mutex.RLock()
handler, shouldIntercept := f.interceptIPs[dstIP]
f.mutex.RUnlock()
if shouldIntercept && handler != nil {
// Make a copy of the packet for the handler (to avoid data races)
packetCopy := make([]byte, size)
copy(packetCopy, packet[:size])
// Call handler in background to avoid blocking packet processing
go handler(packetCopy, DirectionOutbound)
// Intercept the packet (don't send it through the tunnel)
return FilterActionIntercept
}
return FilterActionPass
}
// FilterInbound filters packets coming from tunnel to host
func (f *IPFilter) FilterInbound(packet []byte, size int) FilterAction {
// Fast path: no interceptors configured
f.mutex.RLock()
hasInterceptors := len(f.interceptIPs) > 0
f.mutex.RUnlock()
if !hasInterceptors {
return FilterActionPass
}
// Parse IP header (minimum 20 bytes)
if size < 20 {
return FilterActionPass
}
// Check IP version (IPv4 only for now)
version := packet[0] >> 4
if version != 4 {
return FilterActionPass
}
// Extract source IP (bytes 12-16 in IPv4 header)
srcIP, ok := netip.AddrFromSlice(packet[12:16])
if !ok {
return FilterActionPass
}
// Check if this IP should be intercepted
f.mutex.RLock()
handler, shouldIntercept := f.interceptIPs[srcIP]
f.mutex.RUnlock()
if shouldIntercept && handler != nil {
// Make a copy of the packet for the handler
packetCopy := make([]byte, size)
copy(packetCopy, packet[:size])
// Call handler in background
go handler(packetCopy, DirectionInbound)
// Intercept the packet (don't deliver to host)
return FilterActionIntercept
}
return FilterActionPass
}
// ParsePacketInfo extracts useful information from a packet for debugging/logging
type PacketInfo struct {
Version uint8
Protocol uint8
SrcIP netip.Addr
DstIP netip.Addr
SrcPort uint16
DstPort uint16
IsUDP bool
IsTCP bool
PayloadLen int
}
// ParsePacket extracts packet information (useful for handlers)
func ParsePacket(packet []byte) (*PacketInfo, bool) {
if len(packet) < 20 {
return nil, false
}
info := &PacketInfo{}
// IP version
info.Version = packet[0] >> 4
if info.Version != 4 {
return nil, false
}
// Protocol
info.Protocol = packet[9]
info.IsUDP = info.Protocol == 17
info.IsTCP = info.Protocol == 6
// Source and destination IPs
if srcIP, ok := netip.AddrFromSlice(packet[12:16]); ok {
info.SrcIP = srcIP
}
if dstIP, ok := netip.AddrFromSlice(packet[16:20]); ok {
info.DstIP = dstIP
}
// Get IP header length
ihl := int(packet[0]&0x0f) * 4
if len(packet) < ihl {
return info, true
}
// Extract ports for TCP/UDP
if (info.IsUDP || info.IsTCP) && len(packet) >= ihl+4 {
info.SrcPort = binary.BigEndian.Uint16(packet[ihl : ihl+2])
info.DstPort = binary.BigEndian.Uint16(packet[ihl+2 : ihl+4])
}
// Payload length
totalLen := binary.BigEndian.Uint16(packet[2:4])
info.PayloadLen = int(totalLen) - ihl
if info.IsUDP || info.IsTCP {
info.PayloadLen -= 8 // UDP header size
}
return info, true
}