package monitor import ( "context" "fmt" "net" "net/netip" "sync" "time" "github.com/fosrl/newt/bind" "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/api" middleDevice "github.com/fosrl/olm/device" "github.com/fosrl/olm/websocket" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { monitors map[int]*Client mutex sync.Mutex running bool timeout time.Duration maxAttempts int wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice localIP string stack *stack.Stack ep *channel.Endpoint activePorts map[uint16]bool portsLock sync.RWMutex nsCtx context.Context nsCancel context.CancelFunc nsWg sync.WaitGroup // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status holepunchStopChan chan struct{} holepunchUpdateChan chan struct{} // Relay tracking fields relayedPeers map[int]bool // siteID -> whether the peer is currently relayed holepunchMaxAttempts int // max consecutive failures before triggering relay holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor defaultHolepunchMinInterval time.Duration // Minimum interval (initial) defaultHolepunchMaxInterval time.Duration holepunchMinInterval time.Duration // Minimum interval (initial) holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) holepunchBackoffMultiplier float64 // Multiplier for each stable check holepunchStableCount map[int]int // siteID -> consecutive stable status count holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts rapidTestTimeout time.Duration // timeout for each rapid test attempt rapidTestMaxAttempts int // max attempts during rapid test phase // API server for status updates apiServer *api.API // WG connection status tracking wgConnectionStatus map[int]bool // siteID -> WG connected status } // NewPeerMonitor creates a new peer monitor with the given callback func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind, apiServer *api.API) *PeerMonitor { ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, middleDev: middleDev, localIP: localIP, activePorts: make(map[uint16]bool), nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), relayedPeers: make(map[int]bool), holepunchMaxAttempts: 3, // Trigger relay after 3 consecutive failures holepunchFailures: make(map[int]int), // Rapid initial test settings: complete within ~1.5 seconds rapidTestInterval: 200 * time.Millisecond, // 200ms between attempts rapidTestTimeout: 400 * time.Millisecond, // 400ms timeout per attempt rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total apiServer: apiServer, wgConnectionStatus: make(map[int]bool), // Exponential backoff settings for holepunch monitor defaultHolepunchMinInterval: 2 * time.Second, defaultHolepunchMaxInterval: 30 * time.Second, holepunchMinInterval: 2 * time.Second, holepunchMaxInterval: 30 * time.Second, holepunchBackoffMultiplier: 1.5, holepunchStableCount: make(map[int]int), holepunchCurrentInterval: 2 * time.Second, holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { logger.Error("Failed to initialize netstack for peer monitor: %v", err) } // Initialize holepunch tester if sharedBind is available if sharedBind != nil { pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind) } return pm } // SetInterval changes how frequently peers are checked func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() // Update interval for all existing monitors for _, client := range pm.monitors { client.SetPacketInterval(minInterval, maxInterval) } logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() // Update interval for all existing monitors for _, client := range pm.monitors { client.ResetPacketInterval() } } // SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() pm.holepunchMinInterval = minInterval pm.holepunchMaxInterval = maxInterval // Reset current interval to the new minimum pm.holepunchCurrentInterval = minInterval updateChan := pm.holepunchUpdateChan pm.mutex.Unlock() logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval) // Signal the goroutine to apply the new interval if running if updateChan != nil { select { case updateChan <- struct{}{}: default: // Channel full or closed, skip } } } // GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() return pm.holepunchMinInterval, pm.holepunchMaxInterval } func (pm *PeerMonitor) ResetPeerHolepunchInterval() { pm.mutex.Lock() pm.holepunchMinInterval = pm.defaultHolepunchMinInterval pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval updateChan := pm.holepunchUpdateChan pm.mutex.Unlock() logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) // Signal the goroutine to apply the new interval if running if updateChan != nil { select { case updateChan <- struct{}{}: default: // Channel full or closed, skip } } } // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() defer pm.mutex.Unlock() if _, exists := pm.monitors[siteID]; exists { return nil // Already monitoring } // Use our custom dialer that uses netstack client, err := NewClient(endpoint, pm.dial) if err != nil { return err } pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint pm.holepunchStatus[siteID] = false // Initially unknown/disconnected if pm.running { if err := client.StartMonitor(func(status ConnectionStatus) { pm.handleConnectionStatusChange(siteID, status) }); err != nil { return err } } return nil } // update holepunch endpoint for a peer func (pm *PeerMonitor) UpdateHolepunchEndpoint(siteID int, endpoint string) { // Short delay to allow WireGuard peer reconfiguration to complete // The NAT mapping refresh is handled separately by TriggerHolePunch in olm.go pm.mutex.Lock() defer pm.mutex.Unlock() pm.holepunchEndpoints[siteID] = endpoint logger.Debug("Updated holepunch endpoint for site %d to %s", siteID, endpoint) } // RapidTestPeer performs a rapid connectivity test for a newly added peer. // This is designed to quickly determine if holepunch is viable within ~1-2 seconds. // Returns true if the connection is viable (holepunch works), false if it should relay. func (pm *PeerMonitor) RapidTestPeer(siteID int, endpoint string) bool { if pm.holepunchTester == nil { logger.Warn("Cannot perform rapid test: holepunch tester not initialized") return false } pm.mutex.Lock() interval := pm.rapidTestInterval timeout := pm.rapidTestTimeout maxAttempts := pm.rapidTestMaxAttempts pm.mutex.Unlock() logger.Info("Starting rapid holepunch test for site %d at %s (max %d attempts, %v timeout each)", siteID, endpoint, maxAttempts, timeout) for attempt := 1; attempt <= maxAttempts; attempt++ { result := pm.holepunchTester.TestEndpoint(endpoint, timeout) if result.Success { logger.Info("Rapid test: site %d holepunch SUCCEEDED on attempt %d (RTT: %v)", siteID, attempt, result.RTT) // Update status pm.mutex.Lock() pm.holepunchStatus[siteID] = true pm.holepunchFailures[siteID] = 0 pm.mutex.Unlock() return true } if attempt < maxAttempts { time.Sleep(interval) } } logger.Warn("Rapid test: site %d holepunch FAILED after %d attempts, will relay", siteID, maxAttempts) // Update status to reflect failure pm.mutex.Lock() pm.holepunchStatus[siteID] = false pm.holepunchFailures[siteID] = maxAttempts pm.mutex.Unlock() return false } // UpdatePeerEndpoint updates the monitor endpoint for a peer func (pm *PeerMonitor) UpdatePeerEndpoint(siteID int, monitorPeer string) { pm.mutex.Lock() defer pm.mutex.Unlock() client, exists := pm.monitors[siteID] if !exists { logger.Warn("Cannot update endpoint: peer %d not found in monitor", siteID) return } // Update the client's server address client.UpdateServerAddr(monitorPeer) logger.Info("Updated monitor endpoint for site %d to %s", siteID, monitorPeer) } // removePeerUnlocked stops monitoring a peer and removes it from the monitor // This function assumes the mutex is already held by the caller func (pm *PeerMonitor) removePeerUnlocked(siteID int) { client, exists := pm.monitors[siteID] if !exists { return } client.StopMonitor() client.Close() delete(pm.monitors, siteID) } // RemovePeer stops monitoring a peer and removes it from the monitor func (pm *PeerMonitor) RemovePeer(siteID int) { pm.mutex.Lock() defer pm.mutex.Unlock() // remove the holepunch endpoint info delete(pm.holepunchEndpoints, siteID) delete(pm.holepunchStatus, siteID) delete(pm.relayedPeers, siteID) delete(pm.holepunchFailures, siteID) pm.removePeerUnlocked(siteID) } func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) { pm.mutex.Lock() defer pm.mutex.Unlock() delete(pm.holepunchEndpoints, siteID) } // Start begins monitoring all peers func (pm *PeerMonitor) Start() { pm.mutex.Lock() defer pm.mutex.Unlock() if pm.running { return // Already running } pm.running = true // Start monitoring all peers for siteID, client := range pm.monitors { siteIDCopy := siteID // Create a copy for the closure err := client.StartMonitor(func(status ConnectionStatus) { pm.handleConnectionStatusChange(siteIDCopy, status) }) if err != nil { logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err) continue } logger.Info("Started monitoring peer %d\n", siteID) } pm.startHolepunchMonitor() } // handleConnectionStatusChange is called when a peer's connection status changes func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) { pm.mutex.Lock() previousStatus, exists := pm.wgConnectionStatus[siteID] pm.wgConnectionStatus[siteID] = status.Connected isRelayed := pm.relayedPeers[siteID] endpoint := pm.holepunchEndpoints[siteID] pm.mutex.Unlock() // Log status changes if !exists || previousStatus != status.Connected { if status.Connected { logger.Info("WireGuard connection to site %d is CONNECTED (RTT: %v)", siteID, status.RTT) } else { logger.Warn("WireGuard connection to site %d is DISCONNECTED", siteID) } } // Update API with connection status if pm.apiServer != nil { pm.apiServer.UpdatePeerStatus(siteID, status.Connected, status.RTT, endpoint, isRelayed) } } // sendRelay sends a relay message to the server func (pm *PeerMonitor) sendRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{ "siteId": siteID, }) if err != nil { logger.Error("Failed to send registration message: %v", err) return err } logger.Info("Sent relay message") return nil } // RequestRelay is a public method to request relay for a peer. // This is used when rapid initial testing determines holepunch is not viable. func (pm *PeerMonitor) RequestRelay(siteID int) error { return pm.sendRelay(siteID) } // sendUnRelay sends an unrelay message to the server func (pm *PeerMonitor) sendUnRelay(siteID int) error { if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } err := pm.wsClient.SendMessage("olm/wg/unrelay", map[string]interface{}{ "siteId": siteID, }) if err != nil { logger.Error("Failed to send registration message: %v", err) return err } logger.Info("Sent unrelay message") return nil } // Stop stops monitoring all peers func (pm *PeerMonitor) Stop() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) pm.stopHolepunchMonitor() pm.mutex.Lock() defer pm.mutex.Unlock() if !pm.running { return } pm.running = false // Stop all monitors for _, client := range pm.monitors { client.StopMonitor() } } // MarkPeerRelayed marks a peer as currently using relay func (pm *PeerMonitor) MarkPeerRelayed(siteID int, relayed bool) { pm.mutex.Lock() defer pm.mutex.Unlock() pm.relayedPeers[siteID] = relayed if relayed { // Reset failure count when marked as relayed pm.holepunchFailures[siteID] = 0 } } // IsPeerRelayed returns whether a peer is currently using relay func (pm *PeerMonitor) IsPeerRelayed(siteID int) bool { pm.mutex.Lock() defer pm.mutex.Unlock() return pm.relayedPeers[siteID] } // startHolepunchMonitor starts the holepunch connection monitoring // Note: This function assumes the mutex is already held by the caller (called from Start()) func (pm *PeerMonitor) startHolepunchMonitor() error { if pm.holepunchTester == nil { return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)") } if pm.holepunchStopChan != nil { return fmt.Errorf("holepunch monitor already running") } if err := pm.holepunchTester.Start(); err != nil { return fmt.Errorf("failed to start holepunch tester: %w", err) } pm.holepunchStopChan = make(chan struct{}) go pm.runHolepunchMonitor() logger.Info("Started holepunch connection monitor") return nil } // stopHolepunchMonitor stops the holepunch connection monitoring func (pm *PeerMonitor) stopHolepunchMonitor() { pm.mutex.Lock() stopChan := pm.holepunchStopChan pm.holepunchStopChan = nil pm.mutex.Unlock() if stopChan != nil { close(stopChan) } if pm.holepunchTester != nil { pm.holepunchTester.Stop() } logger.Info("Stopped holepunch connection monitor") } // runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff func (pm *PeerMonitor) runHolepunchMonitor() { pm.mutex.Lock() pm.holepunchCurrentInterval = pm.holepunchMinInterval pm.mutex.Unlock() timer := time.NewTimer(0) // Fire immediately for initial check defer timer.Stop() for { select { case <-pm.holepunchStopChan: return case <-pm.holepunchUpdateChan: // Interval settings changed, reset to minimum pm.mutex.Lock() pm.holepunchCurrentInterval = pm.holepunchMinInterval currentInterval := pm.holepunchCurrentInterval pm.mutex.Unlock() timer.Reset(currentInterval) logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: anyStatusChanged := pm.checkHolepunchEndpoints() pm.mutex.Lock() if anyStatusChanged { // Reset to minimum interval on any status change pm.holepunchCurrentInterval = pm.holepunchMinInterval } else { // Apply exponential backoff when stable newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier) if newInterval > pm.holepunchMaxInterval { newInterval = pm.holepunchMaxInterval } pm.holepunchCurrentInterval = newInterval } currentInterval := pm.holepunchCurrentInterval pm.mutex.Unlock() timer.Reset(currentInterval) } } } // checkHolepunchEndpoints tests all holepunch endpoints // Returns true if any endpoint's status changed func (pm *PeerMonitor) checkHolepunchEndpoints() bool { pm.mutex.Lock() // Check if we're still running before doing any work if !pm.running { pm.mutex.Unlock() return false } endpoints := make(map[int]string, len(pm.holepunchEndpoints)) for siteID, endpoint := range pm.holepunchEndpoints { endpoints[siteID] = endpoint } timeout := pm.holepunchTimeout maxAttempts := pm.holepunchMaxAttempts pm.mutex.Unlock() anyStatusChanged := false for siteID, endpoint := range endpoints { logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() // Check if peer was removed while we were testing if _, stillExists := pm.holepunchEndpoints[siteID]; !stillExists { pm.mutex.Unlock() continue // Peer was removed, skip processing } previousStatus, exists := pm.holepunchStatus[siteID] pm.holepunchStatus[siteID] = result.Success isRelayed := pm.relayedPeers[siteID] // Track consecutive failures for relay triggering if result.Success { pm.holepunchFailures[siteID] = 0 } else { pm.holepunchFailures[siteID]++ } failureCount := pm.holepunchFailures[siteID] pm.mutex.Unlock() // Log status changes statusChanged := !exists || previousStatus != result.Success if statusChanged { anyStatusChanged = true if result.Success { logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT) } else { if result.Error != nil { logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error) } else { logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint) } } } // Update API with holepunch status if pm.apiServer != nil { // Update holepunch connection status pm.apiServer.UpdatePeerHolepunchStatus(siteID, result.Success) // Get the current WG connection status for this peer pm.mutex.Lock() wgConnected := pm.wgConnectionStatus[siteID] pm.mutex.Unlock() // Update API - use holepunch endpoint and relay status pm.apiServer.UpdatePeerStatus(siteID, wgConnected, result.RTT, endpoint, isRelayed) } // Handle relay logic based on holepunch status // Check if we're still running before sending relay messages pm.mutex.Lock() stillRunning := pm.running pm.mutex.Unlock() if !stillRunning { return anyStatusChanged // Stop processing if shutdown is in progress } if !result.Success && !isRelayed && failureCount >= maxAttempts { // Holepunch failed and we're not relayed - trigger relay logger.Info("Holepunch to site %d failed %d times, triggering relay", siteID, failureCount) if pm.wsClient != nil { pm.sendRelay(siteID) } } else if result.Success && isRelayed { // Holepunch succeeded and we ARE relayed - switch back to direct logger.Info("Holepunch to site %d succeeded while relayed, switching to direct connection", siteID) if pm.wsClient != nil { pm.sendUnRelay(siteID) } } } return anyStatusChanged } // GetHolepunchStatus returns the current holepunch status for all endpoints func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool { pm.mutex.Lock() defer pm.mutex.Unlock() status := make(map[int]bool, len(pm.holepunchStatus)) for siteID, connected := range pm.holepunchStatus { status[siteID] = connected } return status } // Close stops monitoring and cleans up resources func (pm *PeerMonitor) Close() { // Stop holepunch monitor first (outside of mutex to avoid deadlock) pm.stopHolepunchMonitor() pm.mutex.Lock() defer pm.mutex.Unlock() logger.Debug("PeerMonitor: Starting cleanup") // Stop and close all clients first for siteID, client := range pm.monitors { logger.Debug("PeerMonitor: Stopping client for site %d", siteID) client.StopMonitor() client.Close() delete(pm.monitors, siteID) } pm.running = false // Clean up netstack resources logger.Debug("PeerMonitor: Cancelling netstack context") if pm.nsCancel != nil { pm.nsCancel() // Signal goroutines to stop } // Close the channel endpoint to unblock any pending reads logger.Debug("PeerMonitor: Closing endpoint") if pm.ep != nil { pm.ep.Close() } // Wait for packet sender goroutine to finish with timeout logger.Debug("PeerMonitor: Waiting for goroutines to finish") done := make(chan struct{}) go func() { pm.nsWg.Wait() close(done) }() select { case <-done: logger.Debug("PeerMonitor: Goroutines finished cleanly") case <-time.After(2 * time.Second): logger.Warn("PeerMonitor: Timeout waiting for goroutines to finish, proceeding anyway") } // Destroy the stack last, after all goroutines are done logger.Debug("PeerMonitor: Destroying stack") if pm.stack != nil { pm.stack.Destroy() pm.stack = nil } logger.Debug("PeerMonitor: Cleanup complete") } // // TestPeer tests connectivity to a specific peer // func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { // pm.mutex.Lock() // client, exists := pm.monitors[siteID] // pm.mutex.Unlock() // if !exists { // return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) // } // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // defer cancel() // connected, rtt := client.TestPeerConnection(ctx) // return connected, rtt, nil // } // // TestAllPeers tests connectivity to all peers // func (pm *PeerMonitor) TestAllPeers() map[int]struct { // Connected bool // RTT time.Duration // } { // pm.mutex.Lock() // peers := make(map[int]*Client, len(pm.monitors)) // for siteID, client := range pm.monitors { // peers[siteID] = client // } // pm.mutex.Unlock() // results := make(map[int]struct { // Connected bool // RTT time.Duration // }) // for siteID, client := range peers { // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // connected, rtt := client.TestPeerConnection(ctx) // cancel() // results[siteID] = struct { // Connected bool // RTT time.Duration // }{ // Connected: connected, // RTT: rtt, // } // } // return results // } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { if pm.localIP == "" { return fmt.Errorf("local IP not provided") } addr, err := netip.ParseAddr(pm.localIP) if err != nil { return fmt.Errorf("invalid local IP: %v", err) } // Create gvisor netstack stackOpts := stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, HandleLocal: true, } pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG) pm.stack = stack.New(stackOpts) // Create NIC if err := pm.stack.CreateNIC(1, pm.ep); err != nil { return fmt.Errorf("failed to create NIC: %v", err) } // Add IP address ipBytes := addr.As4() protoAddr := tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(), } if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil { return fmt.Errorf("failed to add protocol address: %v", err) } // Add default route pm.stack.AddRoute(tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, }) // Register filter rule on MiddleDevice // We want to intercept packets destined to our local IP // But ONLY if they are for ports we are listening on pm.middleDev.AddRule(addr, pm.handlePacket) // Start packet sender (Stack -> WG) pm.nsWg.Add(1) go pm.runPacketSender() return nil } // handlePacket is called by MiddleDevice when a packet arrives for our IP func (pm *PeerMonitor) handlePacket(packet []byte) bool { // Check if it's UDP proto, ok := util.GetProtocol(packet) if !ok || proto != 17 { // UDP return false } // Check destination port port, ok := util.GetDestPort(packet) if !ok { return false } // Check if we are listening on this port pm.portsLock.RLock() active := pm.activePorts[uint16(port)] pm.portsLock.RUnlock() if !active { return false } // Inject into netstack version := packet[0] >> 4 pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{ Payload: buffer.MakeWithData(packet), }) switch version { case 4: pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb) case 6: pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb) default: pkb.DecRef() return false } pkb.DecRef() return true // Handled } // runPacketSender reads packets from netstack and injects them into WireGuard func (pm *PeerMonitor) runPacketSender() { defer pm.nsWg.Done() logger.Debug("PeerMonitor: Packet sender goroutine started") for { // Use blocking ReadContext instead of polling - much more CPU efficient // This will block until a packet is available or context is cancelled pkt := pm.ep.ReadContext(pm.nsCtx) if pkt == nil { // Context was cancelled or endpoint closed logger.Debug("PeerMonitor: Packet sender context cancelled, draining packets") // Drain any remaining packets before exiting for { pkt := pm.ep.Read() if pkt == nil { break } pkt.DecRef() } logger.Debug("PeerMonitor: Packet sender goroutine exiting") return } // Extract packet data slices := pkt.AsSlices() if len(slices) > 0 { var totalSize int for _, slice := range slices { totalSize += len(slice) } buf := make([]byte, totalSize) pos := 0 for _, slice := range slices { copy(buf[pos:], slice) pos += len(slice) } // Inject into MiddleDevice (outbound to WG) pm.middleDev.InjectOutbound(buf) } pkt.DecRef() } } // dial creates a UDP connection using the netstack func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) { if pm.stack == nil { return nil, fmt.Errorf("netstack not initialized") } // Parse remote address raddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } // Parse local IP localIP, err := netip.ParseAddr(pm.localIP) if err != nil { return nil, err } ipBytes := localIP.As4() // Create UDP connection // We bind to port 0 (ephemeral) laddr := &tcpip.FullAddress{ NIC: 1, Addr: tcpip.AddrFrom4(ipBytes), Port: 0, } raddrTcpip := &tcpip.FullAddress{ NIC: 1, Addr: tcpip.AddrFrom4([4]byte(raddr.IP.To4())), Port: uint16(raddr.Port), } conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber) if err != nil { return nil, err } // Get local port localAddr := conn.LocalAddr().(*net.UDPAddr) port := uint16(localAddr.Port) // Register port pm.portsLock.Lock() pm.activePorts[port] = true pm.portsLock.Unlock() // Wrap connection to cleanup port on close return &trackedConn{ Conn: conn, pm: pm, port: port, }, nil } func (pm *PeerMonitor) removePort(port uint16) { pm.portsLock.Lock() delete(pm.activePorts, port) pm.portsLock.Unlock() } type trackedConn struct { net.Conn pm *PeerMonitor port uint16 } func (c *trackedConn) Close() error { c.pm.removePort(c.port) if c.Conn != nil { return c.Conn.Close() } return nil }