mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Netstack is working
This commit is contained in:
@@ -19,15 +19,73 @@ type FilterRule struct {
|
||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||
type MiddleDevice struct {
|
||||
tun.Device
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
rules []FilterRule
|
||||
mutex sync.RWMutex
|
||||
readCh chan readResult
|
||||
injectCh chan []byte
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
bufs [][]byte
|
||||
sizes []int
|
||||
offset int
|
||||
n int
|
||||
err error
|
||||
}
|
||||
|
||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||
return &MiddleDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
d := &MiddleDevice{
|
||||
Device: device,
|
||||
rules: make([]FilterRule, 0),
|
||||
readCh: make(chan readResult),
|
||||
injectCh: make(chan []byte, 100),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go d.pump()
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *MiddleDevice) pump() {
|
||||
const defaultOffset = 16
|
||||
batchSize := d.Device.BatchSize()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-d.closed:
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// Allocate buffers for reading
|
||||
// We allocate new buffers for each read to avoid race conditions
|
||||
// since we pass them to the channel
|
||||
bufs := make([][]byte, batchSize)
|
||||
sizes := make([]int, batchSize)
|
||||
for i := range bufs {
|
||||
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||
}
|
||||
|
||||
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||
|
||||
select {
|
||||
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||
case <-d.closed:
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||
select {
|
||||
case d.injectCh <- packet:
|
||||
case <-d.closed:
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +112,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
||||
d.rules = newRules
|
||||
}
|
||||
|
||||
// Close stops the device
|
||||
func (d *MiddleDevice) Close() error {
|
||||
select {
|
||||
case <-d.closed:
|
||||
default:
|
||||
close(d.closed)
|
||||
}
|
||||
return d.Device.Close()
|
||||
}
|
||||
|
||||
// extractDestIP extracts destination IP from packet (fast path)
|
||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
if len(packet) < 20 {
|
||||
@@ -86,9 +154,49 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||
|
||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
n, err = d.Device.Read(bufs, sizes, offset)
|
||||
if err != nil || n == 0 {
|
||||
return n, err
|
||||
select {
|
||||
case res := <-d.readCh:
|
||||
if res.err != nil {
|
||||
return 0, res.err
|
||||
}
|
||||
|
||||
// Copy packets from result to provided buffers
|
||||
count := 0
|
||||
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||
// Handle offset mismatch if necessary
|
||||
// We assume the pump used defaultOffset (16)
|
||||
// If caller asks for different offset, we need to shift
|
||||
src := res.bufs[i]
|
||||
srcOffset := res.offset
|
||||
srcSize := res.sizes[i]
|
||||
|
||||
// Calculate where the packet data starts and ends in src
|
||||
pktData := src[srcOffset : srcOffset+srcSize]
|
||||
|
||||
// Ensure dest buffer is large enough
|
||||
if len(bufs[i]) < offset+len(pktData) {
|
||||
continue // Skip if buffer too small
|
||||
}
|
||||
|
||||
copy(bufs[i][offset:], pktData)
|
||||
sizes[i] = len(pktData)
|
||||
count++
|
||||
}
|
||||
n = count
|
||||
|
||||
case pkt := <-d.injectCh:
|
||||
if len(bufs) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
if len(bufs[0]) < offset+len(pkt) {
|
||||
return 0, nil // Buffer too small
|
||||
}
|
||||
copy(bufs[0][offset:], pkt)
|
||||
sizes[0] = len(pkt)
|
||||
n = 1
|
||||
|
||||
case <-d.closed:
|
||||
return 0, nil // Device closed
|
||||
}
|
||||
|
||||
d.mutex.RLock()
|
||||
@@ -96,7 +204,7 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
|
||||
d.mutex.RUnlock()
|
||||
|
||||
if len(rules) == 0 {
|
||||
return n, err
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Process packets and filter out handled ones
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/bind"
|
||||
@@ -509,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
|
||||
// TODO: seperate adding the callback to this so we can init it above with the interface
|
||||
interfaceIP := wgData.TunnelIP
|
||||
if strings.Contains(interfaceIP, "/") {
|
||||
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||
}
|
||||
|
||||
peerMonitor = peermonitor.NewPeerMonitor(
|
||||
func(siteID int, connected bool, rtt time.Duration) {
|
||||
// Find the site config to get endpoint information
|
||||
@@ -534,6 +540,8 @@ func StartTunnel(config TunnelConfig) {
|
||||
olm,
|
||||
dev,
|
||||
config.Holepunch,
|
||||
middleDev,
|
||||
interfaceIP,
|
||||
)
|
||||
|
||||
for i := range wgData.Sites {
|
||||
|
||||
@@ -3,14 +3,27 @@ package peermonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||
@@ -39,11 +52,23 @@ type PeerMonitor struct {
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
handleRelaySwitch bool // Whether to handle relay switching
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
localIP string
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
activePorts map[uint16]bool
|
||||
portsLock sync.Mutex
|
||||
nsCtx context.Context
|
||||
nsCancel context.CancelFunc
|
||||
nsWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
@@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
handleRelaySwitch: handleRelaySwitch,
|
||||
middleDev: middleDev,
|
||||
localIP: localIP,
|
||||
activePorts: make(map[uint16]bool),
|
||||
nsCtx: ctx,
|
||||
nsCancel: cancel,
|
||||
}
|
||||
|
||||
if err := pm.initNetstack(); err != nil {
|
||||
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
@@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Check if we're already monitoring this peer
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
// Update the endpoint instead of creating a new monitor
|
||||
pm.removePeerUnlocked(siteID)
|
||||
return nil // Already monitoring
|
||||
}
|
||||
|
||||
client, err := wgtester.NewClient(endpoint)
|
||||
// Use our custom dialer that uses netstack
|
||||
client, err := wgtester.NewClient(endpoint, pm.dial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the client with our settings
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
// Store the client and config
|
||||
pm.monitors[siteID] = client
|
||||
pm.configs[siteID] = wgConfig
|
||||
|
||||
// If monitor is already running, start monitoring this peer
|
||||
if pm.running {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteID, status)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
@@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// initNetstack initializes the gvisor netstack
|
||||
func (pm *PeerMonitor) initNetstack() error {
|
||||
if pm.localIP == "" {
|
||||
return fmt.Errorf("local IP not provided")
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid local IP: %v", err)
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
|
||||
pm.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
|
||||
return fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
ipBytes := addr.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
pm.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Register filter rule on MiddleDevice
|
||||
// We want to intercept packets destined to our local IP
|
||||
// But ONLY if they are for ports we are listening on
|
||||
pm.middleDev.AddRule(addr, pm.handlePacket)
|
||||
|
||||
// Start packet sender (Stack -> WG)
|
||||
pm.nsWg.Add(1)
|
||||
go pm.runPacketSender()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePacket is called by MiddleDevice when a packet arrives for our IP
|
||||
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
|
||||
// Check if it's UDP
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // UDP
|
||||
return false
|
||||
}
|
||||
|
||||
// Check destination port
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we are listening on this port
|
||||
pm.portsLock.Lock()
|
||||
active := pm.activePorts[uint16(port)]
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
if !active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Inject into netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Handled
|
||||
}
|
||||
|
||||
// runPacketSender reads packets from netstack and injects them into WireGuard
|
||||
func (pm *PeerMonitor) runPacketSender() {
|
||||
defer pm.nsWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.nsCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
pkt := pm.ep.Read()
|
||||
if pkt == nil {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract packet data
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
buf := make([]byte, totalSize)
|
||||
pos := 0
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Inject into MiddleDevice (outbound to WG)
|
||||
pm.middleDev.InjectOutbound(buf)
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// dial creates a UDP connection using the netstack
|
||||
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
|
||||
if pm.stack == nil {
|
||||
return nil, fmt.Errorf("netstack not initialized")
|
||||
}
|
||||
|
||||
// Parse remote address
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse local IP
|
||||
localIP, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipBytes := localIP.As4()
|
||||
|
||||
// Create UDP connection
|
||||
// We bind to port 0 (ephemeral)
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
raddrTcpip := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||
Port: uint16(raddr.Port),
|
||||
}
|
||||
|
||||
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get local port
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
port := uint16(localAddr.Port)
|
||||
|
||||
// Register port
|
||||
pm.portsLock.Lock()
|
||||
pm.activePorts[port] = true
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
// Wrap connection to cleanup port on close
|
||||
return &trackedConn{
|
||||
Conn: conn,
|
||||
pm: pm,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) removePort(port uint16) {
|
||||
pm.portsLock.Lock()
|
||||
delete(pm.activePorts, port)
|
||||
pm.portsLock.Unlock()
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
pm *PeerMonitor
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.pm.removePort(c.port)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ const (
|
||||
|
||||
// Client handles checking connectivity to a server
|
||||
type Client struct {
|
||||
conn *net.UDPConn
|
||||
conn net.Conn
|
||||
serverAddr string
|
||||
monitorRunning bool
|
||||
monitorLock sync.Mutex
|
||||
@@ -35,8 +35,12 @@ type Client struct {
|
||||
packetInterval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
dialer Dialer
|
||||
}
|
||||
|
||||
// Dialer is a function that creates a connection
|
||||
type Dialer func(network, addr string) (net.Conn, error)
|
||||
|
||||
// ConnectionStatus represents the current connection state
|
||||
type ConnectionStatus struct {
|
||||
Connected bool
|
||||
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
|
||||
}
|
||||
|
||||
// NewClient creates a new connection test client
|
||||
func NewClient(serverAddr string) (*Client, error) {
|
||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
packetInterval: 2 * time.Second,
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -91,12 +96,14 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
var err error
|
||||
if c.dialer != nil {
|
||||
c.conn, err = c.dialer("udp", c.serverAddr)
|
||||
} else {
|
||||
// Fallback to standard net.Dial
|
||||
c.conn, err = net.Dial("udp", c.serverAddr)
|
||||
}
|
||||
|
||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user