mirror of
https://github.com/fosrl/olm.git
synced 2026-02-09 22:46:39 +00:00
Netstack is working
This commit is contained in:
@@ -3,14 +3,27 @@ package peermonitor
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/util"
|
||||
middleDevice "github.com/fosrl/olm/device"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
"golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||
@@ -39,11 +52,23 @@ type PeerMonitor struct {
|
||||
wsClient *websocket.Client
|
||||
device *device.Device
|
||||
handleRelaySwitch bool // Whether to handle relay switching
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
localIP string
|
||||
stack *stack.Stack
|
||||
ep *channel.Endpoint
|
||||
activePorts map[uint16]bool
|
||||
portsLock sync.Mutex
|
||||
nsCtx context.Context
|
||||
nsCancel context.CancelFunc
|
||||
nsWg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
configs: make(map[int]*WireGuardConfig),
|
||||
callback: callback,
|
||||
@@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
|
||||
wsClient: wsClient,
|
||||
device: device,
|
||||
handleRelaySwitch: handleRelaySwitch,
|
||||
middleDev: middleDev,
|
||||
localIP: localIP,
|
||||
activePorts: make(map[uint16]bool),
|
||||
nsCtx: ctx,
|
||||
nsCancel: cancel,
|
||||
}
|
||||
|
||||
if err := pm.initNetstack(); err != nil {
|
||||
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||
}
|
||||
|
||||
return pm
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
@@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Check if we're already monitoring this peer
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
// Update the endpoint instead of creating a new monitor
|
||||
pm.removePeerUnlocked(siteID)
|
||||
return nil // Already monitoring
|
||||
}
|
||||
|
||||
client, err := wgtester.NewClient(endpoint)
|
||||
// Use our custom dialer that uses netstack
|
||||
client, err := wgtester.NewClient(endpoint, pm.dial)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the client with our settings
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
// Store the client and config
|
||||
pm.monitors[siteID] = client
|
||||
pm.configs[siteID] = wgConfig
|
||||
|
||||
// If monitor is already running, start monitoring this peer
|
||||
if pm.running {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
||||
})
|
||||
if err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.handleConnectionStatusChange(siteID, status)
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||
@@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// initNetstack initializes the gvisor netstack
|
||||
func (pm *PeerMonitor) initNetstack() error {
|
||||
if pm.localIP == "" {
|
||||
return fmt.Errorf("local IP not provided")
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid local IP: %v", err)
|
||||
}
|
||||
|
||||
// Create gvisor netstack
|
||||
stackOpts := stack.Options{
|
||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||
HandleLocal: true,
|
||||
}
|
||||
|
||||
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
|
||||
pm.stack = stack.New(stackOpts)
|
||||
|
||||
// Create NIC
|
||||
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
|
||||
return fmt.Errorf("failed to create NIC: %v", err)
|
||||
}
|
||||
|
||||
// Add IP address
|
||||
ipBytes := addr.As4()
|
||||
protoAddr := tcpip.ProtocolAddress{
|
||||
Protocol: ipv4.ProtocolNumber,
|
||||
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||
}
|
||||
|
||||
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||
return fmt.Errorf("failed to add protocol address: %v", err)
|
||||
}
|
||||
|
||||
// Add default route
|
||||
pm.stack.AddRoute(tcpip.Route{
|
||||
Destination: header.IPv4EmptySubnet,
|
||||
NIC: 1,
|
||||
})
|
||||
|
||||
// Register filter rule on MiddleDevice
|
||||
// We want to intercept packets destined to our local IP
|
||||
// But ONLY if they are for ports we are listening on
|
||||
pm.middleDev.AddRule(addr, pm.handlePacket)
|
||||
|
||||
// Start packet sender (Stack -> WG)
|
||||
pm.nsWg.Add(1)
|
||||
go pm.runPacketSender()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handlePacket is called by MiddleDevice when a packet arrives for our IP
|
||||
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
|
||||
// Check if it's UDP
|
||||
proto, ok := util.GetProtocol(packet)
|
||||
if !ok || proto != 17 { // UDP
|
||||
return false
|
||||
}
|
||||
|
||||
// Check destination port
|
||||
port, ok := util.GetDestPort(packet)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check if we are listening on this port
|
||||
pm.portsLock.Lock()
|
||||
active := pm.activePorts[uint16(port)]
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
if !active {
|
||||
return false
|
||||
}
|
||||
|
||||
// Inject into netstack
|
||||
version := packet[0] >> 4
|
||||
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||
Payload: buffer.MakeWithData(packet),
|
||||
})
|
||||
|
||||
switch version {
|
||||
case 4:
|
||||
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||
case 6:
|
||||
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||
default:
|
||||
pkb.DecRef()
|
||||
return false
|
||||
}
|
||||
|
||||
pkb.DecRef()
|
||||
return true // Handled
|
||||
}
|
||||
|
||||
// runPacketSender reads packets from netstack and injects them into WireGuard
|
||||
func (pm *PeerMonitor) runPacketSender() {
|
||||
defer pm.nsWg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-pm.nsCtx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
pkt := pm.ep.Read()
|
||||
if pkt == nil {
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract packet data
|
||||
slices := pkt.AsSlices()
|
||||
if len(slices) > 0 {
|
||||
var totalSize int
|
||||
for _, slice := range slices {
|
||||
totalSize += len(slice)
|
||||
}
|
||||
|
||||
buf := make([]byte, totalSize)
|
||||
pos := 0
|
||||
for _, slice := range slices {
|
||||
copy(buf[pos:], slice)
|
||||
pos += len(slice)
|
||||
}
|
||||
|
||||
// Inject into MiddleDevice (outbound to WG)
|
||||
pm.middleDev.InjectOutbound(buf)
|
||||
}
|
||||
|
||||
pkt.DecRef()
|
||||
}
|
||||
}
|
||||
|
||||
// dial creates a UDP connection using the netstack
|
||||
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
|
||||
if pm.stack == nil {
|
||||
return nil, fmt.Errorf("netstack not initialized")
|
||||
}
|
||||
|
||||
// Parse remote address
|
||||
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Parse local IP
|
||||
localIP, err := netip.ParseAddr(pm.localIP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ipBytes := localIP.As4()
|
||||
|
||||
// Create UDP connection
|
||||
// We bind to port 0 (ephemeral)
|
||||
laddr := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4(ipBytes),
|
||||
Port: 0,
|
||||
}
|
||||
|
||||
raddrTcpip := &tcpip.FullAddress{
|
||||
NIC: 1,
|
||||
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||
Port: uint16(raddr.Port),
|
||||
}
|
||||
|
||||
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get local port
|
||||
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||
port := uint16(localAddr.Port)
|
||||
|
||||
// Register port
|
||||
pm.portsLock.Lock()
|
||||
pm.activePorts[port] = true
|
||||
pm.portsLock.Unlock()
|
||||
|
||||
// Wrap connection to cleanup port on close
|
||||
return &trackedConn{
|
||||
Conn: conn,
|
||||
pm: pm,
|
||||
port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) removePort(port uint16) {
|
||||
pm.portsLock.Lock()
|
||||
delete(pm.activePorts, port)
|
||||
pm.portsLock.Unlock()
|
||||
}
|
||||
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
pm *PeerMonitor
|
||||
port uint16
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.pm.removePort(c.port)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user