Netstack is working

Owen
2025-11-23 16:24:00 -05:00
parent 61c6894f97
commit 4fc751ddbc
4 changed files with 395 additions and 29 deletions

View File

@@ -19,15 +19,73 @@ type FilterRule struct {
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed chan struct{}
}
type readResult struct {
bufs [][]byte
sizes []int
offset int
n int
err error
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
d := &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
readCh: make(chan readResult),
injectCh: make(chan []byte, 100),
closed: make(chan struct{}),
}
go d.pump()
return d
}
func (d *MiddleDevice) pump() {
const defaultOffset = 16
batchSize := d.Device.BatchSize()
for {
select {
case <-d.closed:
return
default:
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := d.Device.Read(bufs, sizes, defaultOffset)
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-d.closed:
return
}
if err != nil {
return
}
}
}
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
select {
case d.injectCh <- packet:
case <-d.closed:
}
}
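The FilterRule type itself sits just above this hunk and isn't shown. Judging from how AddRule(addr, handler), RemoveRule(destIP) and the handler's bool return are used elsewhere in this commit, it plausibly pairs a destination address with a handler callback; a minimal sketch, with assumed field names:

type FilterRule struct {
	destIP  netip.Addr        // packets addressed here are matched
	handler func([]byte) bool // returns true when the packet was consumed and must not reach WireGuard
}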
@@ -54,6 +112,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.rules = newRules
}
// Close stops the device
func (d *MiddleDevice) Close() error {
select {
case <-d.closed:
default:
close(d.closed)
}
return d.Device.Close()
}
// extractDestIP extracts destination IP from packet (fast path)
func extractDestIP(packet []byte) (netip.Addr, bool) {
if len(packet) < 20 {
@@ -86,9 +154,49 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case res := <-d.readCh:
if res.err != nil {
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
// Handle offset mismatch if necessary
// We assume the pump used defaultOffset (16)
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
return 0, nil // Device closed
}
d.mutex.RLock()
@@ -96,7 +204,7 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
d.mutex.RUnlock()
if len(rules) == 0 {
return n, nil
}
// Process packets and filter out handled ones
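The hunk is truncated at this comment. A minimal sketch of what that filtering pass might look like, assuming the FilterRule fields sketched above (destIP, handler) and that handled packets are compacted out of the batch before it is returned to WireGuard:

kept := 0
for i := 0; i < n; i++ {
	pkt := bufs[i][offset : offset+sizes[i]]
	handled := false
	if dst, ok := extractDestIP(pkt); ok {
		for _, rule := range rules {
			if rule.destIP == dst && rule.handler(pkt) {
				handled = true // rule consumed the packet; drop it from the batch
				break
			}
		}
	}
	if !handled {
		if kept != i {
			copy(bufs[kept][offset:], pkt)
			sizes[kept] = sizes[i]
		}
		kept++
	}
}
return kept, nil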

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"runtime" "runtime"
"strings"
"time" "time"
"github.com/fosrl/newt/bind" "github.com/fosrl/newt/bind"
@@ -509,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
}
// TODO: separate adding the callback to this so we can init it above with the interface
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
@@ -534,6 +540,8 @@ func StartTunnel(config TunnelConfig) {
olm,
dev,
config.Holepunch,
middleDev,
interfaceIP,
)
for i := range wgData.Sites {
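The strings check above handles wgData.TunnelIP arriving either as a bare address or in CIDR form ("10.0.0.2/24"). An equivalent, slightly stricter sketch using net/netip (hypothetical, not part of this commit; assumes the netip import is available in this file):

if pfx, err := netip.ParsePrefix(wgData.TunnelIP); err == nil {
	interfaceIP = pfx.Addr().String()
} else if addr, err := netip.ParseAddr(wgData.TunnelIP); err == nil {
	interfaceIP = addr.String()
}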

View File

@@ -3,14 +3,27 @@ package peermonitor
import (
"context"
"fmt"
"net"
"net/netip"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/websocket" "github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/wgtester" "github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PeerMonitorCallback is the function type for connection status change callbacks
@@ -39,11 +52,23 @@ type PeerMonitor struct {
wsClient *websocket.Client
device *device.Device
handleRelaySwitch bool // Whether to handle relay switching
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.Mutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*wgtester.Client),
configs: make(map[int]*WireGuardConfig),
callback: callback,
@@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
wsClient: wsClient,
device: device,
handleRelaySwitch: handleRelaySwitch,
middleDev: middleDev,
localIP: localIP,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
}
if err := pm.initNetstack(); err != nil {
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
}
return pm
}
// SetInterval changes how frequently peers are checked
@@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
pm.mutex.Lock()
defer pm.mutex.Unlock()
if _, exists := pm.monitors[siteID]; exists {
return nil // Already monitoring
}
// Use our custom dialer that uses netstack
client, err := wgtester.NewClient(endpoint, pm.dial)
if err != nil {
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
pm.monitors[siteID] = client
pm.configs[siteID] = wgConfig
if pm.running {
if err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteID, status)
}); err != nil {
return err
}
}
return nil
}
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
@@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct {
return results
}
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
if pm.localIP == "" {
return fmt.Errorf("local IP not provided")
}
addr, err := netip.ParseAddr(pm.localIP)
if err != nil {
return fmt.Errorf("invalid local IP: %v", err)
}
// Create gvisor netstack
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
pm.stack = stack.New(stackOpts)
// Create NIC
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
return fmt.Errorf("failed to create NIC: %v", err)
}
// Add IP address
ipBytes := addr.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return fmt.Errorf("failed to add protocol address: %v", err)
}
// Add default route
pm.stack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
// Register filter rule on MiddleDevice
// We want to intercept packets destined to our local IP
// But ONLY if they are for ports we are listening on
pm.middleDev.AddRule(addr, pm.handlePacket)
// Start packet sender (Stack -> WG)
pm.nsWg.Add(1)
go pm.runPacketSender()
return nil
}
// handlePacket is called by MiddleDevice when a packet arrives for our IP
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
// Check if it's UDP
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // UDP
return false
}
// Check destination port
port, ok := util.GetDestPort(packet)
if !ok {
return false
}
// Check if we are listening on this port
pm.portsLock.Lock()
active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock()
if !active {
return false
}
// Inject into netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Handled
}
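handlePacket relies on util.GetProtocol and util.GetDestPort from newt's util package, whose implementations aren't part of this diff. For the IPv4 case they would need to behave roughly like the following sketch (illustrative only, not the actual newt code; uses encoding/binary):

// Protocol field of the IPv4 header (17 = UDP).
func getProtocolIPv4(pkt []byte) (uint8, bool) {
	if len(pkt) < 20 || pkt[0]>>4 != 4 {
		return 0, false
	}
	return pkt[9], true
}

// Destination port of the UDP (or TCP) header that follows the IPv4 header.
func getDestPortIPv4(pkt []byte) (uint16, bool) {
	if len(pkt) < 20 || pkt[0]>>4 != 4 {
		return 0, false
	}
	ihl := int(pkt[0]&0x0f) * 4
	if ihl < 20 || len(pkt) < ihl+4 {
		return 0, false
	}
	return binary.BigEndian.Uint16(pkt[ihl+2 : ihl+4]), true
}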
// runPacketSender reads packets from netstack and injects them into WireGuard
func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
for {
select {
case <-pm.nsCtx.Done():
return
default:
}
pkt := pm.ep.Read()
if pkt == nil {
time.Sleep(1 * time.Millisecond)
continue
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
// dial creates a UDP connection using the netstack
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
if pm.stack == nil {
return nil, fmt.Errorf("netstack not initialized")
}
// Parse remote address
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
// Parse local IP
localIP, err := netip.ParseAddr(pm.localIP)
if err != nil {
return nil, err
}
ipBytes := localIP.As4()
// Create UDP connection
// We bind to port 0 (ephemeral)
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4(ipBytes),
Port: 0,
}
raddrTcpip := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
Port: uint16(raddr.Port),
}
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
if err != nil {
return nil, err
}
// Get local port
localAddr := conn.LocalAddr().(*net.UDPAddr)
port := uint16(localAddr.Port)
// Register port
pm.portsLock.Lock()
pm.activePorts[port] = true
pm.portsLock.Unlock()
// Wrap connection to cleanup port on close
return &trackedConn{
Conn: conn,
pm: pm,
port: port,
}, nil
}
func (pm *PeerMonitor) removePort(port uint16) {
pm.portsLock.Lock()
delete(pm.activePorts, port)
pm.portsLock.Unlock()
}
type trackedConn struct {
net.Conn
pm *PeerMonitor
port uint16
}
func (c *trackedConn) Close() error {
c.pm.removePort(c.port)
return c.Conn.Close()
}

View File

@@ -26,7 +26,7 @@ const (
// Client handles checking connectivity to a server
type Client struct {
conn net.Conn
serverAddr string
monitorRunning bool
monitorLock sync.Mutex
@@ -35,8 +35,12 @@ type Client struct {
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
}
// Dialer is a function that creates a connection
type Dialer func(network, addr string) (net.Conn, error)
// ConnectionStatus represents the current connection state
type ConnectionStatus struct {
Connected bool
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
}
// NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
@@ -91,12 +96,14 @@ func (c *Client) ensureConnection() error {
return nil
}
var err error
if c.dialer != nil {
c.conn, err = c.dialer("udp", c.serverAddr)
} else {
// Fallback to standard net.Dial
c.conn, err = net.Dial("udp", c.serverAddr)
}
if err != nil {
return err
}
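With the injected Dialer, the caller chooses the transport at construction time and a nil dialer keeps the old net.Dial behavior. Usage sketch (endpoint address hypothetical):

// Netstack-backed client: probe packets ride the WireGuard tunnel via pm.dial.
nsClient, _ := wgtester.NewClient("10.0.0.1:51821", pm.dial)

// Standard client: nil dialer falls back to net.Dial("udp", ...).
osClient, _ := wgtester.NewClient("10.0.0.1:51821", nil)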