package monitor import ( "context" "encoding/binary" "net" "sync" "time" "github.com/fosrl/newt/logger" ) const ( // Magic bytes to identify our packets magicHeader uint32 = 0xDEADBEEF // Request packet type packetTypeRequest uint8 = 1 // Response packet type packetTypeResponse uint8 = 2 // Packet format: // - 4 bytes: magic header (0xDEADBEEF) // - 1 byte: packet type (1 = request, 2 = response) // - 8 bytes: timestamp (for round-trip timing) packetSize = 13 ) // Client handles checking connectivity to a server type Client struct { conn net.Conn serverAddr string monitorRunning bool monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer // Exponential backoff fields defaultMinInterval time.Duration // Default minimum interval (initial) defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) minInterval time.Duration // Minimum interval (initial) maxInterval time.Duration // Maximum interval (cap for backoff) backoffMultiplier float64 // Multiplier for each stable check stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection type Dialer func(network, addr string) (net.Conn, error) // ConnectionStatus represents the current connection state type ConnectionStatus struct { Connected bool RTT time.Duration } // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ serverAddr: serverAddr, shutdownCh: make(chan struct{}), updateCh: make(chan struct{}, 1), packetInterval: 2 * time.Second, defaultMinInterval: 2 * time.Second, defaultMaxInterval: 30 * time.Second, minInterval: 2 * time.Second, maxInterval: 30 * time.Second, backoffMultiplier: 1.5, stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off timeout: 500 * time.Millisecond, // Timeout for individual packets maxAttempts: 3, // Default max attempts dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { c.monitorLock.Lock() c.packetInterval = minInterval c.minInterval = minInterval c.maxInterval = maxInterval updateCh := c.updateCh monitorRunning := c.monitorRunning c.monitorLock.Unlock() // Signal the goroutine to apply the new interval if running if monitorRunning && updateCh != nil { select { case updateCh <- struct{}{}: default: // Channel full or closed, skip } } } func (c *Client) ResetPacketInterval() { c.monitorLock.Lock() c.packetInterval = c.defaultMinInterval c.minInterval = c.defaultMinInterval c.maxInterval = c.defaultMaxInterval updateCh := c.updateCh monitorRunning := c.monitorRunning c.monitorLock.Unlock() // Signal the goroutine to apply the new interval if running if monitorRunning && updateCh != nil { select { case updateCh <- struct{}{}: default: // Channel full or closed, skip } } } // UpdateServerAddr updates the server address and resets the connection func (c *Client) UpdateServerAddr(serverAddr string) { c.connLock.Lock() defer c.connLock.Unlock() // Close existing connection if any if c.conn != nil { c.conn.Close() c.conn = nil } c.serverAddr = serverAddr } // Close cleans up client resources func (c *Client) Close() { c.StopMonitor() c.connLock.Lock() defer c.connLock.Unlock() if c.conn != nil { c.conn.Close() c.conn = nil } } // ensureConnection makes sure we have an active UDP connection func (c *Client) ensureConnection() error { c.connLock.Lock() defer c.connLock.Unlock() if c.conn != nil { return nil } var err error if c.dialer != nil { c.conn, err = c.dialer("udp", c.serverAddr) } else { // Fallback to standard net.Dial c.conn, err = net.Dial("udp", c.serverAddr) } if err != nil { return err } return nil } // TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 } // Prepare packet buffer packet := make([]byte, packetSize) binary.BigEndian.PutUint32(packet[0:4], magicHeader) packet[4] = packetTypeRequest // Reusable response buffer responseBuffer := make([]byte, packetSize) // Send multiple attempts as specified for attempt := 0; attempt < c.maxAttempts; attempt++ { select { case <-ctx.Done(): return false, 0 default: // Add current timestamp to packet timestamp := time.Now().UnixNano() binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp)) // Lock the connection for the entire send/receive operation c.connLock.Lock() // Check if connection is still valid after acquiring lock if c.conn == nil { c.connLock.Unlock() return false, 0 } _, err := c.conn.Write(packet) if err != nil { c.connLock.Unlock() logger.Info("Error sending packet: %v", err) continue } // Set read deadline c.conn.SetReadDeadline(time.Now().Add(c.timeout)) // Wait for response n, err := c.conn.Read(responseBuffer) c.connLock.Unlock() if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // Timeout, try next attempt time.Sleep(100 * time.Millisecond) // Brief pause between attempts continue } logger.Error("Error reading response: %v", err) continue } if n != packetSize { continue // Malformed packet } // Verify response magic := binary.BigEndian.Uint32(responseBuffer[0:4]) packetType := responseBuffer[4] if magic != magicHeader || packetType != packetTypeResponse { continue // Not our response } // Extract the original timestamp and calculate RTT sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13])) rtt := time.Duration(time.Now().UnixNano() - sentTimestamp) return true, rtt } } return false, 0 } // TestConnectionWithTimeout tries to test connection with a timeout // Returns true if connected, false otherwise func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() return c.TestPeerConnection(ctx) } // MonitorCallback is the function type for connection status change callbacks type MonitorCallback func(status ConnectionStatus) // StartMonitor begins monitoring the connection and calls the callback // when the connection status changes func (c *Client) StartMonitor(callback MonitorCallback) error { c.monitorLock.Lock() defer c.monitorLock.Unlock() if c.monitorRunning { logger.Info("Monitor already running") return nil // Already running } if err := c.ensureConnection(); err != nil { return err } c.monitorRunning = true c.shutdownCh = make(chan struct{}) go func() { var lastConnected bool firstRun := true stableCount := 0 currentInterval := c.minInterval timer := time.NewTimer(currentInterval) defer timer.Stop() for { select { case <-c.shutdownCh: return case <-c.updateCh: // Interval settings changed, reset to minimum c.monitorLock.Lock() currentInterval = c.minInterval c.monitorLock.Unlock() // Reset backoff state stableCount = 0 timer.Reset(currentInterval) logger.Debug("Packet interval updated, reset to %v", currentInterval) case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) connected, rtt := c.TestPeerConnection(ctx) cancel() statusChanged := connected != lastConnected // Callback if status changed or it's the first check if statusChanged || firstRun { callback(ConnectionStatus{ Connected: connected, RTT: rtt, }) lastConnected = connected firstRun = false // Reset backoff on status change stableCount = 0 currentInterval = c.minInterval } else { // Status is stable, increment counter stableCount++ // Apply exponential backoff after stable threshold if stableCount >= c.stableCountToBackoff { newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier) if newInterval > c.maxInterval { newInterval = c.maxInterval } currentInterval = newInterval } } // Reset timer with current interval timer.Reset(currentInterval) } } }() return nil } // StopMonitor stops the connection monitoring func (c *Client) StopMonitor() { c.monitorLock.Lock() defer c.monitorLock.Unlock() if !c.monitorRunning { return } close(c.shutdownCh) c.monitorRunning = false }