mirror of
https://github.com/fosrl/olm.git
synced 2026-02-20 03:46:42 +00:00
Netstack is working
This commit is contained in:
@@ -19,15 +19,73 @@ type FilterRule struct {
|
|||||||
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
// MiddleDevice wraps a TUN device with packet filtering capabilities
|
||||||
type MiddleDevice struct {
|
type MiddleDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
rules []FilterRule
|
rules []FilterRule
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
readCh chan readResult
|
||||||
|
injectCh chan []byte
|
||||||
|
closed chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type readResult struct {
|
||||||
|
bufs [][]byte
|
||||||
|
sizes []int
|
||||||
|
offset int
|
||||||
|
n int
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewMiddleDevice creates a new filtered TUN device wrapper
|
// NewMiddleDevice creates a new filtered TUN device wrapper
|
||||||
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
func NewMiddleDevice(device tun.Device) *MiddleDevice {
|
||||||
return &MiddleDevice{
|
d := &MiddleDevice{
|
||||||
Device: device,
|
Device: device,
|
||||||
rules: make([]FilterRule, 0),
|
rules: make([]FilterRule, 0),
|
||||||
|
readCh: make(chan readResult),
|
||||||
|
injectCh: make(chan []byte, 100),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
go d.pump()
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *MiddleDevice) pump() {
|
||||||
|
const defaultOffset = 16
|
||||||
|
batchSize := d.Device.BatchSize()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-d.closed:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate buffers for reading
|
||||||
|
// We allocate new buffers for each read to avoid race conditions
|
||||||
|
// since we pass them to the channel
|
||||||
|
bufs := make([][]byte, batchSize)
|
||||||
|
sizes := make([]int, batchSize)
|
||||||
|
for i := range bufs {
|
||||||
|
bufs[i] = make([]byte, 2048) // Standard MTU + headroom
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := d.Device.Read(bufs, sizes, defaultOffset)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case d.readCh <- readResult{bufs: bufs, sizes: sizes, offset: defaultOffset, n: n, err: err}:
|
||||||
|
case <-d.closed:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InjectOutbound injects a packet to be read by WireGuard (as if it came from TUN)
|
||||||
|
func (d *MiddleDevice) InjectOutbound(packet []byte) {
|
||||||
|
select {
|
||||||
|
case d.injectCh <- packet:
|
||||||
|
case <-d.closed:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,6 +112,16 @@ func (d *MiddleDevice) RemoveRule(destIP netip.Addr) {
|
|||||||
d.rules = newRules
|
d.rules = newRules
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close stops the device
|
||||||
|
func (d *MiddleDevice) Close() error {
|
||||||
|
select {
|
||||||
|
case <-d.closed:
|
||||||
|
default:
|
||||||
|
close(d.closed)
|
||||||
|
}
|
||||||
|
return d.Device.Close()
|
||||||
|
}
|
||||||
|
|
||||||
// extractDestIP extracts destination IP from packet (fast path)
|
// extractDestIP extracts destination IP from packet (fast path)
|
||||||
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
func extractDestIP(packet []byte) (netip.Addr, bool) {
|
||||||
if len(packet) < 20 {
|
if len(packet) < 20 {
|
||||||
@@ -86,9 +154,49 @@ func extractDestIP(packet []byte) (netip.Addr, bool) {
|
|||||||
|
|
||||||
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
// Read intercepts packets going UP from the TUN device (towards WireGuard)
|
||||||
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
n, err = d.Device.Read(bufs, sizes, offset)
|
select {
|
||||||
if err != nil || n == 0 {
|
case res := <-d.readCh:
|
||||||
return n, err
|
if res.err != nil {
|
||||||
|
return 0, res.err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy packets from result to provided buffers
|
||||||
|
count := 0
|
||||||
|
for i := 0; i < res.n && i < len(bufs); i++ {
|
||||||
|
// Handle offset mismatch if necessary
|
||||||
|
// We assume the pump used defaultOffset (16)
|
||||||
|
// If caller asks for different offset, we need to shift
|
||||||
|
src := res.bufs[i]
|
||||||
|
srcOffset := res.offset
|
||||||
|
srcSize := res.sizes[i]
|
||||||
|
|
||||||
|
// Calculate where the packet data starts and ends in src
|
||||||
|
pktData := src[srcOffset : srcOffset+srcSize]
|
||||||
|
|
||||||
|
// Ensure dest buffer is large enough
|
||||||
|
if len(bufs[i]) < offset+len(pktData) {
|
||||||
|
continue // Skip if buffer too small
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(bufs[i][offset:], pktData)
|
||||||
|
sizes[i] = len(pktData)
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
n = count
|
||||||
|
|
||||||
|
case pkt := <-d.injectCh:
|
||||||
|
if len(bufs) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if len(bufs[0]) < offset+len(pkt) {
|
||||||
|
return 0, nil // Buffer too small
|
||||||
|
}
|
||||||
|
copy(bufs[0][offset:], pkt)
|
||||||
|
sizes[0] = len(pkt)
|
||||||
|
n = 1
|
||||||
|
|
||||||
|
case <-d.closed:
|
||||||
|
return 0, nil // Device closed
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mutex.RLock()
|
d.mutex.RLock()
|
||||||
@@ -96,7 +204,7 @@ func (d *MiddleDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err
|
|||||||
d.mutex.RUnlock()
|
d.mutex.RUnlock()
|
||||||
|
|
||||||
if len(rules) == 0 {
|
if len(rules) == 0 {
|
||||||
return n, err
|
return n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process packets and filter out handled ones
|
// Process packets and filter out handled ones
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/bind"
|
"github.com/fosrl/newt/bind"
|
||||||
@@ -509,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: seperate adding the callback to this so we can init it above with the interface
|
// TODO: seperate adding the callback to this so we can init it above with the interface
|
||||||
|
interfaceIP := wgData.TunnelIP
|
||||||
|
if strings.Contains(interfaceIP, "/") {
|
||||||
|
interfaceIP = strings.Split(interfaceIP, "/")[0]
|
||||||
|
}
|
||||||
|
|
||||||
peerMonitor = peermonitor.NewPeerMonitor(
|
peerMonitor = peermonitor.NewPeerMonitor(
|
||||||
func(siteID int, connected bool, rtt time.Duration) {
|
func(siteID int, connected bool, rtt time.Duration) {
|
||||||
// Find the site config to get endpoint information
|
// Find the site config to get endpoint information
|
||||||
@@ -534,6 +540,8 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
olm,
|
olm,
|
||||||
dev,
|
dev,
|
||||||
config.Holepunch,
|
config.Holepunch,
|
||||||
|
middleDev,
|
||||||
|
interfaceIP,
|
||||||
)
|
)
|
||||||
|
|
||||||
for i := range wgData.Sites {
|
for i := range wgData.Sites {
|
||||||
|
|||||||
@@ -3,14 +3,27 @@ package peermonitor
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
|
"github.com/fosrl/newt/util"
|
||||||
|
middleDevice "github.com/fosrl/olm/device"
|
||||||
"github.com/fosrl/olm/websocket"
|
"github.com/fosrl/olm/websocket"
|
||||||
"github.com/fosrl/olm/wgtester"
|
"github.com/fosrl/olm/wgtester"
|
||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||||
@@ -39,11 +52,23 @@ type PeerMonitor struct {
|
|||||||
wsClient *websocket.Client
|
wsClient *websocket.Client
|
||||||
device *device.Device
|
device *device.Device
|
||||||
handleRelaySwitch bool // Whether to handle relay switching
|
handleRelaySwitch bool // Whether to handle relay switching
|
||||||
|
|
||||||
|
// Netstack fields
|
||||||
|
middleDev *middleDevice.MiddleDevice
|
||||||
|
localIP string
|
||||||
|
stack *stack.Stack
|
||||||
|
ep *channel.Endpoint
|
||||||
|
activePorts map[uint16]bool
|
||||||
|
portsLock sync.Mutex
|
||||||
|
nsCtx context.Context
|
||||||
|
nsCancel context.CancelFunc
|
||||||
|
nsWg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor {
|
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
|
||||||
return &PeerMonitor{
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
pm := &PeerMonitor{
|
||||||
monitors: make(map[int]*wgtester.Client),
|
monitors: make(map[int]*wgtester.Client),
|
||||||
configs: make(map[int]*WireGuardConfig),
|
configs: make(map[int]*WireGuardConfig),
|
||||||
callback: callback,
|
callback: callback,
|
||||||
@@ -54,7 +79,18 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
|
|||||||
wsClient: wsClient,
|
wsClient: wsClient,
|
||||||
device: device,
|
device: device,
|
||||||
handleRelaySwitch: handleRelaySwitch,
|
handleRelaySwitch: handleRelaySwitch,
|
||||||
|
middleDev: middleDev,
|
||||||
|
localIP: localIP,
|
||||||
|
activePorts: make(map[uint16]bool),
|
||||||
|
nsCtx: ctx,
|
||||||
|
nsCancel: cancel,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := pm.initNetstack(); err != nil {
|
||||||
|
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pm
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInterval changes how frequently peers are checked
|
// SetInterval changes how frequently peers are checked
|
||||||
@@ -101,35 +137,32 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardC
|
|||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
// Check if we're already monitoring this peer
|
|
||||||
if _, exists := pm.monitors[siteID]; exists {
|
if _, exists := pm.monitors[siteID]; exists {
|
||||||
// Update the endpoint instead of creating a new monitor
|
return nil // Already monitoring
|
||||||
pm.removePeerUnlocked(siteID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := wgtester.NewClient(endpoint)
|
// Use our custom dialer that uses netstack
|
||||||
|
client, err := wgtester.NewClient(endpoint, pm.dial)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure the client with our settings
|
|
||||||
client.SetPacketInterval(pm.interval)
|
client.SetPacketInterval(pm.interval)
|
||||||
client.SetTimeout(pm.timeout)
|
client.SetTimeout(pm.timeout)
|
||||||
client.SetMaxAttempts(pm.maxAttempts)
|
client.SetMaxAttempts(pm.maxAttempts)
|
||||||
|
|
||||||
// Store the client and config
|
|
||||||
pm.monitors[siteID] = client
|
pm.monitors[siteID] = client
|
||||||
pm.configs[siteID] = wgConfig
|
pm.configs[siteID] = wgConfig
|
||||||
|
|
||||||
// If monitor is already running, start monitoring this peer
|
|
||||||
if pm.running {
|
if pm.running {
|
||||||
siteIDCopy := siteID // Create a copy for the closure
|
if err := client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
pm.handleConnectionStatusChange(siteID, status)
|
||||||
pm.handleConnectionStatusChange(siteIDCopy, status)
|
}); err != nil {
|
||||||
})
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
// removePeerUnlocked stops monitoring a peer and removes it from the monitor
|
||||||
@@ -329,3 +362,213 @@ func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initNetstack initializes the gvisor netstack
|
||||||
|
func (pm *PeerMonitor) initNetstack() error {
|
||||||
|
if pm.localIP == "" {
|
||||||
|
return fmt.Errorf("local IP not provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := netip.ParseAddr(pm.localIP)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid local IP: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create gvisor netstack
|
||||||
|
stackOpts := stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
|
||||||
|
HandleLocal: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
|
||||||
|
pm.stack = stack.New(stackOpts)
|
||||||
|
|
||||||
|
// Create NIC
|
||||||
|
if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
|
||||||
|
return fmt.Errorf("failed to create NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IP address
|
||||||
|
ipBytes := addr.As4()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return fmt.Errorf("failed to add protocol address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add default route
|
||||||
|
pm.stack.AddRoute(tcpip.Route{
|
||||||
|
Destination: header.IPv4EmptySubnet,
|
||||||
|
NIC: 1,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Register filter rule on MiddleDevice
|
||||||
|
// We want to intercept packets destined to our local IP
|
||||||
|
// But ONLY if they are for ports we are listening on
|
||||||
|
pm.middleDev.AddRule(addr, pm.handlePacket)
|
||||||
|
|
||||||
|
// Start packet sender (Stack -> WG)
|
||||||
|
pm.nsWg.Add(1)
|
||||||
|
go pm.runPacketSender()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handlePacket is called by MiddleDevice when a packet arrives for our IP
|
||||||
|
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
|
||||||
|
// Check if it's UDP
|
||||||
|
proto, ok := util.GetProtocol(packet)
|
||||||
|
if !ok || proto != 17 { // UDP
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check destination port
|
||||||
|
port, ok := util.GetDestPort(packet)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we are listening on this port
|
||||||
|
pm.portsLock.Lock()
|
||||||
|
active := pm.activePorts[uint16(port)]
|
||||||
|
pm.portsLock.Unlock()
|
||||||
|
|
||||||
|
if !active {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into netstack
|
||||||
|
version := packet[0] >> 4
|
||||||
|
pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(packet),
|
||||||
|
})
|
||||||
|
|
||||||
|
switch version {
|
||||||
|
case 4:
|
||||||
|
pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
|
||||||
|
case 6:
|
||||||
|
pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
|
||||||
|
default:
|
||||||
|
pkb.DecRef()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
pkb.DecRef()
|
||||||
|
return true // Handled
|
||||||
|
}
|
||||||
|
|
||||||
|
// runPacketSender reads packets from netstack and injects them into WireGuard
|
||||||
|
func (pm *PeerMonitor) runPacketSender() {
|
||||||
|
defer pm.nsWg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-pm.nsCtx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := pm.ep.Read()
|
||||||
|
if pkt == nil {
|
||||||
|
time.Sleep(1 * time.Millisecond)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract packet data
|
||||||
|
slices := pkt.AsSlices()
|
||||||
|
if len(slices) > 0 {
|
||||||
|
var totalSize int
|
||||||
|
for _, slice := range slices {
|
||||||
|
totalSize += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := make([]byte, totalSize)
|
||||||
|
pos := 0
|
||||||
|
for _, slice := range slices {
|
||||||
|
copy(buf[pos:], slice)
|
||||||
|
pos += len(slice)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Inject into MiddleDevice (outbound to WG)
|
||||||
|
pm.middleDev.InjectOutbound(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt.DecRef()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dial creates a UDP connection using the netstack
|
||||||
|
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
|
||||||
|
if pm.stack == nil {
|
||||||
|
return nil, fmt.Errorf("netstack not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse remote address
|
||||||
|
raddr, err := net.ResolveUDPAddr("udp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse local IP
|
||||||
|
localIP, err := netip.ParseAddr(pm.localIP)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ipBytes := localIP.As4()
|
||||||
|
|
||||||
|
// Create UDP connection
|
||||||
|
// We bind to port 0 (ephemeral)
|
||||||
|
laddr := &tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.AddrFrom4(ipBytes),
|
||||||
|
Port: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
raddrTcpip := &tcpip.FullAddress{
|
||||||
|
NIC: 1,
|
||||||
|
Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())),
|
||||||
|
Port: uint16(raddr.Port),
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get local port
|
||||||
|
localAddr := conn.LocalAddr().(*net.UDPAddr)
|
||||||
|
port := uint16(localAddr.Port)
|
||||||
|
|
||||||
|
// Register port
|
||||||
|
pm.portsLock.Lock()
|
||||||
|
pm.activePorts[port] = true
|
||||||
|
pm.portsLock.Unlock()
|
||||||
|
|
||||||
|
// Wrap connection to cleanup port on close
|
||||||
|
return &trackedConn{
|
||||||
|
Conn: conn,
|
||||||
|
pm: pm,
|
||||||
|
port: port,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pm *PeerMonitor) removePort(port uint16) {
|
||||||
|
pm.portsLock.Lock()
|
||||||
|
delete(pm.activePorts, port)
|
||||||
|
pm.portsLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
pm *PeerMonitor
|
||||||
|
port uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.pm.removePort(c.port)
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ const (
|
|||||||
|
|
||||||
// Client handles checking connectivity to a server
|
// Client handles checking connectivity to a server
|
||||||
type Client struct {
|
type Client struct {
|
||||||
conn *net.UDPConn
|
conn net.Conn
|
||||||
serverAddr string
|
serverAddr string
|
||||||
monitorRunning bool
|
monitorRunning bool
|
||||||
monitorLock sync.Mutex
|
monitorLock sync.Mutex
|
||||||
@@ -35,8 +35,12 @@ type Client struct {
|
|||||||
packetInterval time.Duration
|
packetInterval time.Duration
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
|
dialer Dialer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dialer is a function that creates a connection
|
||||||
|
type Dialer func(network, addr string) (net.Conn, error)
|
||||||
|
|
||||||
// ConnectionStatus represents the current connection state
|
// ConnectionStatus represents the current connection state
|
||||||
type ConnectionStatus struct {
|
type ConnectionStatus struct {
|
||||||
Connected bool
|
Connected bool
|
||||||
@@ -44,13 +48,14 @@ type ConnectionStatus struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewClient creates a new connection test client
|
// NewClient creates a new connection test client
|
||||||
func NewClient(serverAddr string) (*Client, error) {
|
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||||
return &Client{
|
return &Client{
|
||||||
serverAddr: serverAddr,
|
serverAddr: serverAddr,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
packetInterval: 2 * time.Second,
|
packetInterval: 2 * time.Second,
|
||||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||||
maxAttempts: 3, // Default max attempts
|
maxAttempts: 3, // Default max attempts
|
||||||
|
dialer: dialer,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -91,12 +96,14 @@ func (c *Client) ensureConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
|
var err error
|
||||||
if err != nil {
|
if c.dialer != nil {
|
||||||
return err
|
c.conn, err = c.dialer("udp", c.serverAddr)
|
||||||
|
} else {
|
||||||
|
// Fallback to standard net.Dial
|
||||||
|
c.conn, err = net.Dial("udp", c.serverAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn, err = net.DialUDP("udp", nil, serverAddr)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user