Netstack is working

Former-commit-id: 4fc751ddbc
This commit is contained in:
Owen
2025-11-23 16:24:00 -05:00
parent 47d628af73
commit 7afe842a95
4 changed files with 395 additions and 29 deletions

View File

@@ -19,15 +19,73 @@ type FilterRule struct {
// MiddleDevice wraps a TUN device with packet filtering capabilities
type MiddleDevice struct {
tun.Device
rules []FilterRule
mutex sync.RWMutex
rules []FilterRule
mutex sync.RWMutex
readCh chan readResult
injectCh chan []byte
closed chan struct{}
}
type readResult struct {
bufs [][]byte
sizes []int
offset int
n int
err error
}
// NewMiddleDevice creates a new filtered TUN device wrapper
func NewMiddleDevice(device tun.Device) *MiddleDevice {
return &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
d := &MiddleDevice{
Device: device,
rules: make([]FilterRule, 0),
readCh: make(chan readResult),
injectCh: make(chan []byte, 100),
closed: make(chan struct{}),
}
go d.pump()
return d
}
func (d *MiddleDevice) pump() {
const defaultOffset = 16
batchSize := d.Device.BatchSize()
for {
select {
case <-d.closed:
return
default:
}
// Allocate buffers for reading
// We allocate new buffers for each read to avoid race conditions
// since we pass them to the channel
bufs := make([][]byte, batchSize)
sizes := make([]int, batchSize)
for i := range bufs {
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
}
n, err := d.Device.Read(bufs, sizes, defaultOffset)
select {
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
case <-d.closed:
return
}
if err != nil {
return
}
}
}
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
func (d *MiddleDevice) InjectOutbound(packet []byte) {
select {
case d.injectCh <- packet:
case <-d.closed:
}
}
@@ -54,6 +112,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
d.rules = newRules
}
// Close stops the device
func (d *MiddleDevice) Close() error {
select {
case <-d.closed:
default:
close(d.closed)
}
return d.Device.Close()
}
// extractDestIP extracts destination IP from packet (fast path)
func extractDestIP(packet []byte) (netip.Addr, bool) {
if len(packet) < 20 {
@@ -86,9 +154,49 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
// Read intercepts packets going UP from the TUN device (towards WireGuard)
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
n, err = d.Device.Read(bufs, sizes, offset)
if err != nil || n == 0 {
return n, err
select {
case res := <-d.readCh:
if res.err != nil {
return 0, res.err
}
// Copy packets from result to provided buffers
count := 0
for i := 0; i < res.n && i < len(bufs); i++ {
// Handle offset mismatch if necessary
// We assume the pump used defaultOffset (16)
// If caller asks for different offset, we need to shift
src := res.bufs[i]
srcOffset := res.offset
srcSize := res.sizes[i]
// Calculate where the packet data starts and ends in src
pktData := src[srcOffset : srcOffset+srcSize]
// Ensure dest buffer is large enough
if len(bufs[i]) < offset+len(pktData) {
continue // Skip if buffer too small
}
copy(bufs[i][offset:], pktData)
sizes[i] = len(pktData)
count++
}
n = count
case pkt := <-d.injectCh:
if len(bufs) == 0 {
return 0, nil
}
if len(bufs[0]) < offset+len(pkt) {
return 0, nil // Buffer too small
}
copy(bufs[0][offset:], pkt)
sizes[0] = len(pkt)
n = 1
case <-d.closed:
return 0, nil // Device closed
}
d.mutex.RLock()
@@ -96,7 +204,7 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
d.mutex.RUnlock()
if len(rules) == 0 {
return n, err
return n, nil
}
// Process packets and filter out handled ones

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"net"
"runtime"
"strings"
"time"
"github.com/fosrl/newt/bind"
@@ -509,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
}
// TODO: seperate adding the callback to this so we can init it above with the interface
interfaceIP := wgData.TunnelIP
if strings.Contains(interfaceIP, "/") {
interfaceIP = strings.Split(interfaceIP, "/")[0]
}
peerMonitor = peermonitor.NewPeerMonitor(
func(siteID int, connected bool, rtt time.Duration) {
// Find the site config to get endpoint information
@@ -534,6 +540,8 @@ func StartTunnel(config TunnelConfig) {
olm,
dev,
config.Holepunch,
middleDev,
interfaceIP,
)
for i := range wgData.Sites {

View File

@@ -3,14 +3,27 @@ package peermonitor
import (
"context"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/util"
middleDevice "github.com/fosrl/olm/device"
"github.com/fosrl/olm/websocket"
"github.com/fosrl/olm/wgtester"
"golang.zx2c4.com/wireguard/device"
"gvisor.dev/gvisor/pkg/buffer"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)
// PeerMonitorCallback is the function type for connection status change callbacks
@@ -39,11 +52,23 @@ type PeerMonitor struct {
wsClient *websocket.Client
device *device.Device
handleRelaySwitch bool // Whether to handle relay switching
// Netstack fields
middleDev *middleDevice.MiddleDevice
localIP string
stack *stack.Stack
ep *channel.Endpoint
activePorts map[uint16]bool
portsLock sync.Mutex
nsCtx context.Context
nsCancel context.CancelFunc
nsWg sync.WaitGroup
}
// NewPeerMonitor creates a new peer monitor with the given callback
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
return &PeerMonitor{
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*wgtester.Client),
configs: make(map[int]*WireGuardConfig),
callback: callback,
@@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
wsClient: wsClient,
device: device,
handleRelaySwitch: handleRelaySwitch,
middleDev: middleDev,
localIP: localIP,
activePorts: make(map[uint16]bool),
nsCtx: ctx,
nsCancel: cancel,
}
if err := pm.initNetstack(); err != nil {
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
}
return pm
}
// SetInterval changes how frequently peers are checked
@@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
pm.mutex.Lock()
defer pm.mutex.Unlock()
// Check if we're already monitoring this peer
if _, exists := pm.monitors[siteID]; exists {
// Update the endpoint instead of creating a new monitor
pm.removePeerUnlocked(siteID)
return nil // Already monitoring
}
client, err := wgtester.NewClient(endpoint)
// Use our custom dialer that uses netstack
client, err := wgtester.NewClient(endpoint, pm.dial)
if err != nil {
return err
}
// Configure the client with our settings
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
// Store the client and config
pm.monitors[siteID] = client
pm.configs[siteID] = wgConfig
// If monitor is already running, start monitoring this peer
if pm.running {
siteIDCopy := siteID // Create a copy for the closure
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteIDCopy, status)
})
if err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
pm.handleConnectionStatusChange(siteID, status)
}); err != nil {
return err
}
}
return err
return nil
}
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
@@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct {
return results
}
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {
if pm.localIP == "" {
return fmt.Errorf("local IP not provided")
}
addr, err := netip.ParseAddr(pm.localIP)
if err != nil {
return fmt.Errorf("invalid local IP: %v", err)
}
// Create gvisor netstack
stackOpts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
HandleLocal: true,
}
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
pm.stack = stack.New(stackOpts)
// Create NIC
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
return fmt.Errorf("failed to create NIC: %v", err)
}
// Add IP address
ipBytes := addr.As4()
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
}
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
return fmt.Errorf("failed to add protocol address: %v", err)
}
// Add default route
pm.stack.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: 1,
})
// Register filter rule on MiddleDevice
// We want to intercept packets destined to our local IP
// But ONLY if they are for ports we are listening on
pm.middleDev.AddRule(addr, pm.handlePacket)
// Start packet sender (Stack -> WG)
pm.nsWg.Add(1)
go pm.runPacketSender()
return nil
}
// handlePacket is called by MiddleDevice when a packet arrives for our IP
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
// Check if it's UDP
proto, ok := util.GetProtocol(packet)
if !ok || proto != 17 { // UDP
return false
}
// Check destination port
port, ok := util.GetDestPort(packet)
if !ok {
return false
}
// Check if we are listening on this port
pm.portsLock.Lock()
active := pm.activePorts[uint16(port)]
pm.portsLock.Unlock()
if !active {
return false
}
// Inject into netstack
version := packet[0] >> 4
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
Payload: buffer.MakeWithData(packet),
})
switch version {
case 4:
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
case 6:
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
default:
pkb.DecRef()
return false
}
pkb.DecRef()
return true // Handled
}
// runPacketSender reads packets from netstack and injects them into WireGuard
func (pm *PeerMonitor) runPacketSender() {
defer pm.nsWg.Done()
for {
select {
case <-pm.nsCtx.Done():
return
default:
}
pkt := pm.ep.Read()
if pkt == nil {
time.Sleep(1 * time.Millisecond)
continue
}
// Extract packet data
slices := pkt.AsSlices()
if len(slices) > 0 {
var totalSize int
for _, slice := range slices {
totalSize += len(slice)
}
buf := make([]byte, totalSize)
pos := 0
for _, slice := range slices {
copy(buf[pos:], slice)
pos += len(slice)
}
// Inject into MiddleDevice (outbound to WG)
pm.middleDev.InjectOutbound(buf)
}
pkt.DecRef()
}
}
// dial creates a UDP connection using the netstack
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
if pm.stack == nil {
return nil, fmt.Errorf("netstack not initialized")
}
// Parse remote address
raddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
// Parse local IP
localIP, err := netip.ParseAddr(pm.localIP)
if err != nil {
return nil, err
}
ipBytes := localIP.As4()
// Create UDP connection
// We bind to port 0 (ephemeral)
laddr := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4(ipBytes),
Port: 0,
}
raddrTcpip := &tcpip.FullAddress{
NIC: 1,
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
Port: uint16(raddr.Port),
}
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
if err != nil {
return nil, err
}
// Get local port
localAddr := conn.LocalAddr().(*net.UDPAddr)
port := uint16(localAddr.Port)
// Register port
pm.portsLock.Lock()
pm.activePorts[port] = true
pm.portsLock.Unlock()
// Wrap connection to cleanup port on close
return &trackedConn{
Conn: conn,
pm: pm,
port: port,
}, nil
}
func (pm *PeerMonitor) removePort(port uint16) {
pm.portsLock.Lock()
delete(pm.activePorts, port)
pm.portsLock.Unlock()
}
type trackedConn struct {
net.Conn
pm *PeerMonitor
port uint16
}
func (c *trackedConn) Close() error {
c.pm.removePort(c.port)
return c.Conn.Close()
}

View File

@@ -26,7 +26,7 @@ const (
// Client handles checking connectivity to a server
type Client struct {
conn *net.UDPConn
conn net.Conn
serverAddr string
monitorRunning bool
monitorLock sync.Mutex
@@ -35,8 +35,12 @@ type Client struct {
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
}
// Dialer is a function that creates a connection
type Dialer func(network, addr string) (net.Conn, error)
// ConnectionStatus represents the current connection state
type ConnectionStatus struct {
Connected bool
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
}
// NewClient creates a new connection test client
func NewClient(serverAddr string) (*Client, error) {
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
@@ -91,12 +96,14 @@ func (c *Client) ensureConnection() error {
return nil
}
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
return err
var err error
if c.dialer != nil {
c.conn, err = c.dialer("udp", c.serverAddr)
} else {
// Fallback to standard net.Dial
c.conn, err = net.Dial("udp", c.serverAddr)
}
c.conn, err = net.DialUDP("udp", nil, serverAddr)
if err != nil {
return err
}