Reorg the files

Former-commit-id: 5505c1d2c7
This commit is contained in:
Owen
2025-11-21 16:53:54 -05:00
parent 7941479994
commit d7cd746cc9
4 changed files with 35 additions and 76 deletions

190
device/middle_device.go Normal file
View File

@@ -0,0 +1,190 @@
package device
import (
"net/netip"
"sync"
"golang.zx2c4.com/wireguard/tun"
)
// PacketHandler processes intercepted packets and returns true if packet should be dropped
type PacketHandler func(packet []byte) bool
// FilterRule defines a rule for packet filtering
type FilterRule struct {
DestIP netip.Addr
Handler PacketHandler
}
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
return &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
}
}
// AddRule adds a packet filtering rule
func (d *MiddleDevice) AddRule(destIP netip.Addr, handler PacketHandler) {
d.mutex.Lock()
defer d.mutex.Unlock()
d.rules = append(d.rules, FilterRule{
DestIP: destIP,
Handler: handler,
})
}
// RemoveRule removes all rules for a given destination IP
func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.mutex.Lock()
defer d.mutex.Unlock()
newRules := make([]FilterRule, 0, len(d.rules))
for _, rule := range d.rules {
if rule.DestIP != destIP {
newRules = append(newRules, rule)
}
}
d.rules = newRules
}
// extractDestIP extracts destination IP from packet (fast path)
func extractDestIP(packet []byte) (netip.Addr, bool) {
if len(packet) < 20 {
return netip.Addr{}, false
}
version := packet[0] >> 4
switch version {
case 4:
if len(packet) < 20 {
return netip.Addr{}, false
}
// Destination IP is at bytes 16-19 for IPv4
ip := netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
return ip, true
case 6:
if len(packet) < 40 {
return netip.Addr{}, false
}
// Destination IP is at bytes 24-39 for IPv6
var ip16 [16]byte
copy(ip16[:], packet[24:40])
ip := netip.AddrFrom16(ip16)
return ip, true
}
return netip.Addr{}, false
}
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 {
return n, err
}
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return n, err
}
// Process packets and filter out handled ones
writeIdx := 0
for readIdx := 0; readIdx < n; readIdx++ {
packet := bufs[readIdx][offset : offset+sizes[readIdx]]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
// Keep packet
if writeIdx != readIdx {
bufs[writeIdx] = bufs[readIdx]
sizes[writeIdx] = sizes[readIdx]
}
writeIdx++
}
}
return writeIdx, err
}
// Write intercepts packets going DOWN to the TUN device (from WireGuard)
func (d *MiddleDevice) Write(bufs [][]byte, offset int) (int, error) {
d.mutex.RLock()
rules := d.rules
d.mutex.RUnlock()
if len(rules) == 0 {
return d.Device.Write(bufs, offset)
}
// Filter packets going down
filteredBufs := make([][]byte, 0, len(bufs))
for _, buf := range bufs {
if len(buf) <= offset {
continue
}
packet := buf[offset:]
destIP, ok := extractDestIP(packet)
if !ok {
// Can't parse, keep packet
filteredBufs = append(filteredBufs, buf)
continue
}
// Check if packet matches any rule
handled := false
for _, rule := range rules {
if rule.DestIP == destIP {
if rule.Handler(packet) {
// Packet was handled and should be dropped
handled = true
break
}
}
}
if !handled {
filteredBufs = append(filteredBufs, buf)
}
}
if len(filteredBufs) == 0 {
return len(bufs), nil // All packets were handled
}
return d.Device.Write(filteredBufs, offset)
}

View File

@@ -0,0 +1,102 @@
package device
import (
"net/netip"
"testing"
"github.com/fosrl/newt/util"
)
func TestExtractDestIP(t *testing.T) {
tests := []struct {
name string
packet []byte
wantIP string
wantOk bool
}{
{
name: "IPv4 packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e, // Dest IP: 10.30.30.30
},
wantIP: "10.30.30.30",
wantOk: true,
},
{
name: "Too short packet",
packet: []byte{0x45, 0x00},
wantIP: "",
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotIP, gotOk := extractDestIP(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("extractDestIP() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if tt.wantOk {
wantAddr := netip.MustParseAddr(tt.wantIP)
if gotIP != wantAddr {
t.Errorf("extractDestIP() ip = %v, want %v", gotIP, wantAddr)
}
}
})
}
}
func TestGetProtocol(t *testing.T) {
tests := []struct {
name string
packet []byte
wantProto uint8
wantOk bool
}{
{
name: "UDP packet",
packet: []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01, // Protocol: UDP (17) at byte 9
0x0a, 0x1e, 0x1e, 0x1e,
},
wantProto: 17,
wantOk: true,
},
{
name: "Too short",
packet: []byte{0x45, 0x00},
wantProto: 0,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotProto, gotOk := util.GetProtocol(tt.packet)
if gotOk != tt.wantOk {
t.Errorf("GetProtocol() ok = %v, want %v", gotOk, tt.wantOk)
return
}
if gotProto != tt.wantProto {
t.Errorf("GetProtocol() proto = %v, want %v", gotProto, tt.wantProto)
}
})
}
}
func BenchmarkExtractDestIP(b *testing.B) {
packet := []byte{
0x45, 0x00, 0x00, 0x54, 0x00, 0x00, 0x40, 0x00,
0x40, 0x11, 0x00, 0x00, 0xc0, 0xa8, 0x01, 0x01,
0x0a, 0x1e, 0x1e, 0x1e,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
extractDestIP(packet)
}
}