From 4fc751ddbcd101faa175e35ae839dd5395cf58bc Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 23 Nov 2025 16:24:00 -0500 Subject: [PATCH] Netstack is working --- device/middle_device.go | 126 +++++++++++++++-- olm/olm.go | 8 ++ peermonitor/peermonitor.go | 271 +++++++++++++++++++++++++++++++++++-- wgtester/wgtester.go | 19 ++- 4 files changed, 395 insertions(+), 29 deletions(-) diff --git a/device/middle_device.go b/device/middle_device.go index 82c13ac..809ce1b 100644 --- a/device/middle_device.go +++ b/device/middle_device.go @@ -19,15 +19,73 @@ type FilterRule struct { // MiddleDevice wraps a TUN device with packet filtering capabilities type MiddleDevice struct { tun.Device - rules []FilterRule - mutex sync.RWMutex + rules []FilterRule + mutex sync.RWMutex + readCh chan readResult + injectCh chan []byte + closed chan struct{} +} + +type readResult struct { + bufs [][]byte + sizes []int + offset int + n int + err error } // NewMiddleDevice creates a new filtered TUN device wrapper func NewMiddleDevice(device tun.Device) *MiddleDevice { - return &MiddleDevice{ - Device: device, - rules: make([]FilterRule, 0), + d := &MiddleDevice{ + Device: device, + rules: make([]FilterRule, 0), + readCh: make(chan readResult), + injectCh: make(chan []byte, 100), + closed: make(chan struct{}), + } + go d.pump() + return d +} + +func (d *MiddleDevice) pump() { + const defaultOffset = 16 + batchSize := d.Device.BatchSize() + + for { + select { + case <-d.closed: + return + default: + } + + // Allocate buffers for reading + // We allocate new buffers for each read to avoid race conditions + // since we pass them to the channel + bufs := make([][]byte, batchSize) + sizes := make([]int, batchSize) + for i := range bufs { + bufs[i] = make([]byte, 2048) // Standard MTU + headroom + } + + n, err := d.Device.Read(bufs, sizes, defaultOffset) + + select { + case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}: + case <-d.closed: + return + } + + if err != nil { + return + } + } +} + +// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN) +func (d *MiddleDevice) InjectOutbound(packet []byte) { + select { + case d.injectCh <- packet: + case <-d.closed: } } @@ -54,6 +112,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) { d.rules = newRules } +// Close stops the device +func (d *MiddleDevice) Close() error { + select { + case <-d.closed: + default: + close(d.closed) + } + return d.Device.Close() +} + // extractDestIP extracts destination IP from packet (fast path) func extractDestIP(packet []byte) (netip.Addr, bool) { if len(packet) < 20 { @@ -86,9 +154,49 @@ func extractDestIP(packet []byte) (netip.Addr, bool) { // Read intercepts packets going UP from the TUN device (towards WireGuard) func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { - n, err = d.Device.Read(bufs, sizes, offset) - if err != nil || n == 0 { - return n, err + select { + case res := <-d.readCh: + if res.err != nil { + return 0, res.err + } + + // Copy packets from result to provided buffers + count := 0 + for i := 0; i < res.n && i < len(bufs); i++ { + // Handle offset mismatch if necessary + // We assume the pump used defaultOffset (16) + // If caller asks for different offset, we need to shift + src := res.bufs[i] + srcOffset := res.offset + srcSize := res.sizes[i] + + // Calculate where the packet data starts and ends in src + pktData := src[srcOffset : srcOffset+srcSize] + + // Ensure dest buffer is large enough + if len(bufs[i]) < offset+len(pktData) { + continue // Skip if buffer too small + } + + copy(bufs[i][offset:], pktData) + sizes[i] = len(pktData) + count++ + } + n = count + + case pkt := <-d.injectCh: + if len(bufs) == 0 { + return 0, nil + } + if len(bufs[0]) < offset+len(pkt) { + return 0, nil // Buffer too small + } + copy(bufs[0][offset:], pkt) + sizes[0] = len(pkt) + n = 1 + + case <-d.closed: + return 0, nil // Device closed } d.mutex.RLock() @@ -96,7 +204,7 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err d.mutex.RUnlock() if len(rules) == 0 { - return n, err + return n, nil } // Process packets and filter out handled ones diff --git a/olm/olm.go b/olm/olm.go index 1d4dc5b..3dce73a 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "runtime" + "strings" "time" "github.com/fosrl/newt/bind" @@ -509,6 +510,11 @@ func StartTunnel(config TunnelConfig) { } // TODO: seperate adding the callback to this so we can init it above with the interface + interfaceIP := wgData.TunnelIP + if strings.Contains(interfaceIP, "/") { + interfaceIP = strings.Split(interfaceIP, "/")[0] + } + peerMonitor = peermonitor.NewPeerMonitor( func(siteID int, connected bool, rtt time.Duration) { // Find the site config to get endpoint information @@ -534,6 +540,8 @@ func StartTunnel(config TunnelConfig) { olm, dev, config.Holepunch, + middleDev, + interfaceIP, ) for i := range wgData.Sites { diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go index afa8248..d8254f5 100644 --- a/peermonitor/peermonitor.go +++ b/peermonitor/peermonitor.go @@ -3,14 +3,27 @@ package peermonitor import ( "context" "fmt" + "net" + "net/netip" "strings" "sync" "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/util" + middleDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/websocket" "github.com/fosrl/olm/wgtester" "golang.zx2c4.com/wireguard/device" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) // PeerMonitorCallback is the function type for connection status change callbacks @@ -39,11 +52,23 @@ type PeerMonitor struct { wsClient *websocket.Client device *device.Device handleRelaySwitch bool // Whether to handle relay switching + + // Netstack fields + middleDev *middleDevice.MiddleDevice + localIP string + stack *stack.Stack + ep *channel.Endpoint + activePorts map[uint16]bool + portsLock sync.Mutex + nsCtx context.Context + nsCancel context.CancelFunc + nsWg sync.WaitGroup } // NewPeerMonitor creates a new peer monitor with the given callback -func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor { - return &PeerMonitor{ +func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor { + ctx, cancel := context.WithCancel(context.Background()) + pm := &PeerMonitor{ monitors: make(map[int]*wgtester.Client), configs: make(map[int]*WireGuardConfig), callback: callback, @@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w wsClient: wsClient, device: device, handleRelaySwitch: handleRelaySwitch, + middleDev: middleDev, + localIP: localIP, + activePorts: make(map[uint16]bool), + nsCtx: ctx, + nsCancel: cancel, } + + if err := pm.initNetstack(); err != nil { + logger.Error("Failed to initialize netstack for peer monitor: %v", err) + } + + return pm } // SetInterval changes how frequently peers are checked @@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC pm.mutex.Lock() defer pm.mutex.Unlock() - // Check if we're already monitoring this peer if _, exists := pm.monitors[siteID]; exists { - // Update the endpoint instead of creating a new monitor - pm.removePeerUnlocked(siteID) + return nil // Already monitoring } - client, err := wgtester.NewClient(endpoint) + // Use our custom dialer that uses netstack + client, err := wgtester.NewClient(endpoint, pm.dial) if err != nil { return err } - // Configure the client with our settings client.SetPacketInterval(pm.interval) client.SetTimeout(pm.timeout) client.SetMaxAttempts(pm.maxAttempts) - // Store the client and config pm.monitors[siteID] = client pm.configs[siteID] = wgConfig - // If monitor is already running, start monitoring this peer if pm.running { - siteIDCopy := siteID // Create a copy for the closure - err = client.StartMonitor(func(status wgtester.ConnectionStatus) { - pm.handleConnectionStatusChange(siteIDCopy, status) - }) + if err := client.StartMonitor(func(status wgtester.ConnectionStatus) { + pm.handleConnectionStatusChange(siteID, status) + }); err != nil { + return err + } } - return err + return nil } // removePeerUnlocked stops monitoring a peer and removes it from the monitor @@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct { return results } + +// initNetstack initializes the gvisor netstack +func (pm *PeerMonitor) initNetstack() error { + if pm.localIP == "" { + return fmt.Errorf("local IP not provided") + } + + addr, err := netip.ParseAddr(pm.localIP) + if err != nil { + return fmt.Errorf("invalid local IP: %v", err) + } + + // Create gvisor netstack + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG) + pm.stack = stack.New(stackOpts) + + // Create NIC + if err := pm.stack.CreateNIC(1, pm.ep); err != nil { + return fmt.Errorf("failed to create NIC: %v", err) + } + + // Add IP address + ipBytes := addr.As4() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), + } + + if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { + return fmt.Errorf("failed to add protocol address: %v", err) + } + + // Add default route + pm.stack.AddRoute(tcpip.Route{ + Destination: header.IPv4EmptySubnet, + NIC: 1, + }) + + // Register filter rule on MiddleDevice + // We want to intercept packets destined to our local IP + // But ONLY if they are for ports we are listening on + pm.middleDev.AddRule(addr, pm.handlePacket) + + // Start packet sender (Stack -> WG) + pm.nsWg.Add(1) + go pm.runPacketSender() + + return nil +} + +// handlePacket is called by MiddleDevice when a packet arrives for our IP +func (pm *PeerMonitor) handlePacket(packet []byte) bool { + // Check if it's UDP + proto, ok := util.GetProtocol(packet) + if !ok || proto != 17 { // UDP + return false + } + + // Check destination port + port, ok := util.GetDestPort(packet) + if !ok { + return false + } + + // Check if we are listening on this port + pm.portsLock.Lock() + active := pm.activePorts[uint16(port)] + pm.portsLock.Unlock() + + if !active { + return false + } + + // Inject into netstack + version := packet[0] >> 4 + pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(packet), + }) + + switch version { + case 4: + pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb) + case 6: + pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb) + default: + pkb.DecRef() + return false + } + + pkb.DecRef() + return true // Handled +} + +// runPacketSender reads packets from netstack and injects them into WireGuard +func (pm *PeerMonitor) runPacketSender() { + defer pm.nsWg.Done() + + for { + select { + case <-pm.nsCtx.Done(): + return + default: + } + + pkt := pm.ep.Read() + if pkt == nil { + time.Sleep(1 * time.Millisecond) + continue + } + + // Extract packet data + slices := pkt.AsSlices() + if len(slices) > 0 { + var totalSize int + for _, slice := range slices { + totalSize += len(slice) + } + + buf := make([]byte, totalSize) + pos := 0 + for _, slice := range slices { + copy(buf[pos:], slice) + pos += len(slice) + } + + // Inject into MiddleDevice (outbound to WG) + pm.middleDev.InjectOutbound(buf) + } + + pkt.DecRef() + } +} + +// dial creates a UDP connection using the netstack +func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) { + if pm.stack == nil { + return nil, fmt.Errorf("netstack not initialized") + } + + // Parse remote address + raddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + + // Parse local IP + localIP, err := netip.ParseAddr(pm.localIP) + if err != nil { + return nil, err + } + ipBytes := localIP.As4() + + // Create UDP connection + // We bind to port 0 (ephemeral) + laddr := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4(ipBytes), + Port: 0, + } + + raddrTcpip := &tcpip.FullAddress{ + NIC: 1, + Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())), + Port: uint16(raddr.Port), + } + + conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber) + if err != nil { + return nil, err + } + + // Get local port + localAddr := conn.LocalAddr().(*net.UDPAddr) + port := uint16(localAddr.Port) + + // Register port + pm.portsLock.Lock() + pm.activePorts[port] = true + pm.portsLock.Unlock() + + // Wrap connection to cleanup port on close + return &trackedConn{ + Conn: conn, + pm: pm, + port: port, + }, nil +} + +func (pm *PeerMonitor) removePort(port uint16) { + pm.portsLock.Lock() + delete(pm.activePorts, port) + pm.portsLock.Unlock() +} + +type trackedConn struct { + net.Conn + pm *PeerMonitor + port uint16 +} + +func (c *trackedConn) Close() error { + c.pm.removePort(c.port) + return c.Conn.Close() +} diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index 28ffdba..b8aacef 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -26,7 +26,7 @@ const ( // Client handles checking connectivity to a server type Client struct { - conn *net.UDPConn + conn net.Conn serverAddr string monitorRunning bool monitorLock sync.Mutex @@ -35,8 +35,12 @@ type Client struct { packetInterval time.Duration timeout time.Duration maxAttempts int + dialer Dialer } +// Dialer is a function that creates a connection +type Dialer func(network, addr string) (net.Conn, error) + // ConnectionStatus represents the current connection state type ConnectionStatus struct { Connected bool @@ -44,13 +48,14 @@ type ConnectionStatus struct { } // NewClient creates a new connection test client -func NewClient(serverAddr string) (*Client, error) { +func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ serverAddr: serverAddr, shutdownCh: make(chan struct{}), packetInterval: 2 * time.Second, timeout: 500 * time.Millisecond, // Timeout for individual packets maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } @@ -91,12 +96,14 @@ func (c *Client) ensureConnection() error { return nil } - serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) - if err != nil { - return err + var err error + if c.dialer != nil { + c.conn, err = c.dialer("udp", c.serverAddr) + } else { + // Fallback to standard net.Dial + c.conn, err = net.Dial("udp", c.serverAddr) } - c.conn, err = net.DialUDP("udp", nil, serverAddr) if err != nil { return err }