package peermonitor

import (
	"context"
	"fmt"
	"net"
	"net/netip"
	"strings"
	"sync"
	"time"

	"github.com/fosrl/newt/logger"
	"github.com/fosrl/newt/util"
	middleDevice "github.com/fosrl/olm/device"
	"github.com/fosrl/olm/websocket"
	"golang.zx2c4.com/wireguard/device"
	"gvisor.dev/gvisor/pkg/buffer"
	"gvisor.dev/gvisor/pkg/tcpip"
	"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
	"gvisor.dev/gvisor/pkg/tcpip/header"
	"gvisor.dev/gvisor/pkg/tcpip/link/channel"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
	"gvisor.dev/gvisor/pkg/tcpip/stack"
	"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
)

// PeerMonitorCallback is the function type for connection status change
// callbacks. It receives the site ID, whether the peer is currently
// reachable, and the most recent measured round-trip time.
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)

// WireGuardConfig holds the WireGuard configuration for a peer.
type WireGuardConfig struct {
	SiteID       int
	PublicKey    string
	ServerIP     string
	Endpoint     string
	PrimaryRelay string // The primary relay endpoint
}

// PeerMonitor handles monitoring the connection status to multiple
// WireGuard peers. Probe traffic is carried over a userspace gvisor
// netstack whose packets are injected into / extracted from the
// MiddleDevice, so probes travel inside the WireGuard tunnel.
type PeerMonitor struct {
	monitors          map[int]*Client
	configs           map[int]*WireGuardConfig
	callback          PeerMonitorCallback
	mutex             sync.Mutex // guards all fields above plus interval/timeout/maxAttempts/running
	running           bool
	interval          time.Duration
	timeout           time.Duration
	maxAttempts       int
	privateKey        string
	wsClient          *websocket.Client
	device            *device.Device
	handleRelaySwitch bool // Whether to handle relay switching

	// Netstack fields
	middleDev   *middleDevice.MiddleDevice
	localIP     string
	stack       *stack.Stack
	ep          *channel.Endpoint
	activePorts map[uint16]bool // UDP ports with a live monitor connection
	portsLock   sync.Mutex      // guards activePorts
	nsCtx       context.Context
	nsCancel    context.CancelFunc
	nsWg        sync.WaitGroup // tracks the packet-sender goroutine
}

// NewPeerMonitor creates a new peer monitor with the given callback.
// localIP must be an IPv4 address assigned to this olm inside the tunnel;
// it becomes the source address of all probe traffic. If netstack
// initialization fails the error is logged and the monitor is returned
// anyway (probes will then fail to dial).
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
	ctx, cancel := context.WithCancel(context.Background())
	pm := &PeerMonitor{
		monitors:          make(map[int]*Client),
		configs:           make(map[int]*WireGuardConfig),
		callback:          callback,
		interval:          1 * time.Second, // Default check interval
		timeout:           2500 * time.Millisecond,
		maxAttempts:       8,
		privateKey:        privateKey,
		wsClient:          wsClient,
		device:            device,
		handleRelaySwitch: handleRelaySwitch,
		middleDev:         middleDev,
		localIP:           localIP,
		activePorts:       make(map[uint16]bool),
		nsCtx:             ctx,
		nsCancel:          cancel,
	}

	if err := pm.initNetstack(); err != nil {
		logger.Error("Failed to initialize netstack for peer monitor: %v", err)
	}

	return pm
}

// SetInterval changes how frequently peers are checked.
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.interval = interval

	// Update interval for all existing monitors
	for _, client := range pm.monitors {
		client.SetPacketInterval(interval)
	}
}

// SetTimeout changes the timeout for waiting for responses.
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.timeout = timeout

	// Update timeout for all existing monitors
	for _, client := range pm.monitors {
		client.SetTimeout(timeout)
	}
}

// SetMaxAttempts changes the maximum number of attempts for TestConnection.
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	pm.maxAttempts = attempts

	// Update max attempts for all existing monitors
	for _, client := range pm.monitors {
		client.SetMaxAttempts(attempts)
	}
}

// AddPeer adds a new peer to monitor. If the monitor is already running,
// monitoring of the new peer starts immediately. Adding an already-known
// siteID is a no-op.
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, wgConfig *WireGuardConfig) error {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if _, exists := pm.monitors[siteID]; exists {
		return nil // Already monitoring
	}

	// Use our custom dialer that uses netstack
	client, err := NewClient(endpoint, pm.dial)
	if err != nil {
		return err
	}

	client.SetPacketInterval(pm.interval)
	client.SetTimeout(pm.timeout)
	client.SetMaxAttempts(pm.maxAttempts)

	pm.monitors[siteID] = client
	pm.configs[siteID] = wgConfig

	if pm.running {
		if err := client.StartMonitor(func(status ConnectionStatus) {
			pm.handleConnectionStatusChange(siteID, status)
		}); err != nil {
			return err
		}
	}

	return nil
}

// removePeerUnlocked stops monitoring a peer and removes it from the monitor.
// This function assumes the mutex is already held by the caller.
func (pm *PeerMonitor) removePeerUnlocked(siteID int) {
	client, exists := pm.monitors[siteID]
	if !exists {
		return
	}

	client.StopMonitor()
	client.Close()
	delete(pm.monitors, siteID)
	delete(pm.configs, siteID)
}

// RemovePeer stops monitoring a peer and removes it from the monitor.
func (pm *PeerMonitor) RemovePeer(siteID int) {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()
	pm.removePeerUnlocked(siteID)
}

// Start begins monitoring all peers. Peers whose monitor fails to start
// are logged and skipped; they remain registered.
func (pm *PeerMonitor) Start() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if pm.running {
		return // Already running
	}

	pm.running = true

	// Start monitoring all peers
	for siteID, client := range pm.monitors {
		siteIDCopy := siteID // Create a copy for the closure (pre-Go 1.22 loop semantics)
		err := client.StartMonitor(func(status ConnectionStatus) {
			pm.handleConnectionStatusChange(siteIDCopy, status)
		})
		if err != nil {
			logger.Error("Failed to start monitoring peer %d: %v\n", siteID, err)
			continue
		}
		logger.Info("Started monitoring peer %d\n", siteID)
	}
}

// handleConnectionStatusChange is called when a peer's connection status
// changes. It forwards the status to the user callback and, on loss of
// connectivity, asks the server to switch the peer to a relay.
func (pm *PeerMonitor) handleConnectionStatusChange(siteID int, status ConnectionStatus) {
	// Call the user-provided callback first
	if pm.callback != nil {
		pm.callback(siteID, status.Connected, status.RTT)
	}

	// If disconnected, handle failover
	if !status.Connected {
		// Send relay message to the server
		if pm.wsClient != nil {
			pm.sendRelay(siteID)
		}
	}
}

// HandleFailover handles failover to the relay server when a peer is
// disconnected: it re-points the WireGuard peer's endpoint at the relay
// on port 21820. Unknown siteIDs are ignored.
func (pm *PeerMonitor) HandleFailover(siteID int, relayEndpoint string) {
	pm.mutex.Lock()
	config, exists := pm.configs[siteID]
	pm.mutex.Unlock()

	if !exists {
		return
	}

	// Check for IPv6 and format the endpoint correctly (bracket bare IPv6
	// literals so the appended ":21820" parses as a port).
	formattedEndpoint := relayEndpoint
	if strings.Contains(relayEndpoint, ":") {
		formattedEndpoint = fmt.Sprintf("[%s]", relayEndpoint)
	}

	// Configure WireGuard to use the relay
	wgConfig := fmt.Sprintf(`private_key=%s
public_key=%s
allowed_ip=%s/32
endpoint=%s:21820
persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.ServerIP, formattedEndpoint)

	err := pm.device.IpcSet(wgConfig)
	if err != nil {
		logger.Error("Failed to configure WireGuard device: %v\n", err)
		return
	}

	logger.Info("Adjusted peer %d to point to relay!\n", siteID)
}

// sendRelay sends a relay message to the server asking it to move the
// given site onto a relay. It is a no-op when relay switching is disabled.
func (pm *PeerMonitor) sendRelay(siteID int) error {
	if !pm.handleRelaySwitch {
		return nil
	}

	if pm.wsClient == nil {
		return fmt.Errorf("websocket client is nil")
	}

	err := pm.wsClient.SendMessage("olm/wg/relay", map[string]interface{}{
		"siteId": siteID,
	})
	if err != nil {
		logger.Error("Failed to send relay message: %v", err)
		return err
	}

	logger.Info("Sent relay message")
	return nil
}

// Stop stops monitoring all peers. Peers stay registered and can be
// resumed with Start.
func (pm *PeerMonitor) Stop() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	if !pm.running {
		return
	}

	pm.running = false

	// Stop all monitors
	for _, client := range pm.monitors {
		client.StopMonitor()
	}
}

// Close stops monitoring, tears down all clients, and shuts down the
// netstack packet-sender goroutine. The monitor must not be used afterwards.
func (pm *PeerMonitor) Close() {
	pm.mutex.Lock()
	defer pm.mutex.Unlock()

	// Stop and close all clients
	for siteID, client := range pm.monitors {
		client.StopMonitor()
		client.Close()
		delete(pm.monitors, siteID)
		delete(pm.configs, siteID)
	}

	pm.running = false

	// Stop the packet-sender goroutine and wait for it to exit; without
	// this, runPacketSender leaks (nsCancel was otherwise never called).
	if pm.nsCancel != nil {
		pm.nsCancel()
	}
	pm.nsWg.Wait()
}

// TestPeer tests connectivity to a specific peer, returning whether it
// responded and the measured round-trip time.
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
	// Snapshot settings under the lock: timeout/maxAttempts are written
	// under pm.mutex by the setters.
	pm.mutex.Lock()
	client, exists := pm.monitors[siteID]
	timeout := pm.timeout
	maxAttempts := pm.maxAttempts
	pm.mutex.Unlock()

	if !exists {
		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
	}

	ctx, cancel := context.WithTimeout(context.Background(), timeout*time.Duration(maxAttempts))
	defer cancel()

	connected, rtt := client.TestConnection(ctx)
	return connected, rtt, nil
}

// TestAllPeers tests connectivity to all peers sequentially and returns
// a map of siteID to result.
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
	Connected bool
	RTT       time.Duration
} {
	pm.mutex.Lock()
	peers := make(map[int]*Client, len(pm.monitors))
	for siteID, client := range pm.monitors {
		peers[siteID] = client
	}
	// Snapshot settings under the lock to avoid racing with the setters.
	timeout := pm.timeout
	maxAttempts := pm.maxAttempts
	pm.mutex.Unlock()

	results := make(map[int]struct {
		Connected bool
		RTT       time.Duration
	})

	for siteID, client := range peers {
		ctx, cancel := context.WithTimeout(context.Background(), timeout*time.Duration(maxAttempts))
		connected, rtt := client.TestConnection(ctx)
		cancel()

		results[siteID] = struct {
			Connected bool
			RTT       time.Duration
		}{
			Connected: connected,
			RTT:       rtt,
		}
	}

	return results
}

// initNetstack initializes the gvisor netstack used to source probe
// traffic from localIP, and registers a packet filter on the MiddleDevice
// so inbound packets for active probe ports are diverted into the stack.
func (pm *PeerMonitor) initNetstack() error {
	if pm.localIP == "" {
		return fmt.Errorf("local IP not provided")
	}
	if pm.middleDev == nil {
		return fmt.Errorf("middle device not provided")
	}

	addr, err := netip.ParseAddr(pm.localIP)
	if err != nil {
		return fmt.Errorf("invalid local IP: %v", err)
	}
	// addr.As4 panics on non-IPv4 addresses, so reject them explicitly;
	// the NIC below is only given an IPv4 protocol address anyway.
	if !addr.Is4() {
		return fmt.Errorf("local IP %s is not IPv4", pm.localIP)
	}

	// Create gvisor netstack
	stackOpts := stack.Options{
		NetworkProtocols:   []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
		TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol},
		HandleLocal:        true,
	}
	pm.ep = channel.New(256, 1420, "") // MTU 1420 (standard WG)
	pm.stack = stack.New(stackOpts)

	// Create NIC
	if err := pm.stack.CreateNIC(1, pm.ep); err != nil {
		return fmt.Errorf("failed to create NIC: %v", err)
	}

	// Add IP address
	ipBytes := addr.As4()
	protoAddr := tcpip.ProtocolAddress{
		Protocol:          ipv4.ProtocolNumber,
		AddressWithPrefix: tcpip.AddrFrom4(ipBytes).WithPrefix(),
	}
	if err := pm.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{}); err != nil {
		return fmt.Errorf("failed to add protocol address: %v", err)
	}

	// Add default route
	pm.stack.AddRoute(tcpip.Route{
		Destination: header.IPv4EmptySubnet,
		NIC:         1,
	})

	// Register filter rule on MiddleDevice.
	// We want to intercept packets destined to our local IP,
	// but ONLY if they are for ports we are listening on.
	pm.middleDev.AddRule(addr, pm.handlePacket)

	// Start packet sender (Stack -> WG)
	pm.nsWg.Add(1)
	go pm.runPacketSender()

	return nil
}

// handlePacket is called by MiddleDevice when a packet arrives for our IP.
// It returns true if the packet was consumed (injected into the netstack),
// false to let it continue on its normal path.
func (pm *PeerMonitor) handlePacket(packet []byte) bool {
	if len(packet) == 0 {
		return false
	}

	// Check if it's UDP
	proto, ok := util.GetProtocol(packet)
	if !ok || proto != 17 { // UDP
		return false
	}

	// Check destination port
	port, ok := util.GetDestPort(packet)
	if !ok {
		return false
	}

	// Check if we are listening on this port
	pm.portsLock.Lock()
	active := pm.activePorts[uint16(port)]
	pm.portsLock.Unlock()

	if !active {
		return false
	}

	// Inject into netstack; the IP version lives in the high nibble of
	// the first byte.
	version := packet[0] >> 4
	pkb := stack.NewPacketBuffer(stack.PacketBufferOptions{
		Payload: buffer.MakeWithData(packet),
	})
	switch version {
	case 4:
		pm.ep.InjectInbound(ipv4.ProtocolNumber, pkb)
	case 6:
		pm.ep.InjectInbound(ipv6.ProtocolNumber, pkb)
	default:
		pkb.DecRef()
		return false
	}
	pkb.DecRef()
	return true // Handled
}

// runPacketSender reads outbound packets from the netstack and injects
// them into WireGuard via the MiddleDevice. It exits when nsCtx is
// cancelled (see Close).
func (pm *PeerMonitor) runPacketSender() {
	defer pm.nsWg.Done()

	for {
		select {
		case <-pm.nsCtx.Done():
			return
		default:
		}

		pkt := pm.ep.Read()
		if pkt == nil {
			// channel.Endpoint.Read is non-blocking; back off briefly.
			time.Sleep(1 * time.Millisecond)
			continue
		}

		// Flatten the packet's slices into one contiguous buffer.
		slices := pkt.AsSlices()
		if len(slices) > 0 {
			var totalSize int
			for _, slice := range slices {
				totalSize += len(slice)
			}
			buf := make([]byte, totalSize)
			pos := 0
			for _, slice := range slices {
				copy(buf[pos:], slice)
				pos += len(slice)
			}

			// Inject into MiddleDevice (outbound to WG)
			pm.middleDev.InjectOutbound(buf)
		}
		pkt.DecRef()
	}
}

// dial creates a UDP connection through the netstack, bound to localIP on
// an ephemeral port. The port is registered in activePorts so handlePacket
// diverts replies into the stack, and deregistered when the connection is
// closed. Only IPv4 remote addresses are supported.
func (pm *PeerMonitor) dial(network, addr string) (net.Conn, error) {
	if pm.stack == nil {
		return nil, fmt.Errorf("netstack not initialized")
	}

	// Parse remote address
	raddr, err := net.ResolveUDPAddr("udp", addr)
	if err != nil {
		return nil, err
	}
	// To4 returns nil for IPv6; converting nil to [4]byte would panic.
	rip4 := raddr.IP.To4()
	if rip4 == nil {
		return nil, fmt.Errorf("remote address %s is not IPv4", addr)
	}

	// Parse local IP
	localIP, err := netip.ParseAddr(pm.localIP)
	if err != nil {
		return nil, err
	}
	if !localIP.Is4() {
		return nil, fmt.Errorf("local IP %s is not IPv4", pm.localIP)
	}
	ipBytes := localIP.As4()

	// Create UDP connection; we bind to port 0 (ephemeral).
	laddr := &tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.AddrFrom4(ipBytes),
		Port: 0,
	}
	raddrTcpip := &tcpip.FullAddress{
		NIC:  1,
		Addr: tcpip.AddrFrom4([4]byte(rip4)),
		Port: uint16(raddr.Port),
	}

	conn, err := gonet.DialUDP(pm.stack, laddr, raddrTcpip, ipv4.ProtocolNumber)
	if err != nil {
		return nil, err
	}

	// Get local port
	localAddr := conn.LocalAddr().(*net.UDPAddr)
	port := uint16(localAddr.Port)

	// Register port
	pm.portsLock.Lock()
	pm.activePorts[port] = true
	pm.portsLock.Unlock()

	// Wrap connection to cleanup port on close
	return &trackedConn{
		Conn: conn,
		pm:   pm,
		port: port,
	}, nil
}

// removePort deregisters an ephemeral probe port so handlePacket stops
// diverting packets for it.
func (pm *PeerMonitor) removePort(port uint16) {
	pm.portsLock.Lock()
	delete(pm.activePorts, port)
	pm.portsLock.Unlock()
}

// trackedConn wraps a netstack connection so its probe port is
// deregistered when the connection is closed.
type trackedConn struct {
	net.Conn
	pm   *PeerMonitor
	port uint16
}

// Close deregisters the connection's port and closes the underlying conn.
func (c *trackedConn) Close() error {
	c.pm.removePort(c.port)
	return c.Conn.Close()
}