Add wgtester

This commit is contained in:
Owen
2025-03-31 18:11:04 -04:00
parent aeb8f203a4
commit e35f7c2d36
3 changed files with 419 additions and 43 deletions

347
wgtester/wgtester.go Normal file
View File

@@ -0,0 +1,347 @@
package wgtester
import (
"context"
"encoding/binary"
"log"
"net"
"sync"
"time"
)
const (
// Magic bytes to identify our packets
magicHeader uint32 = 0xDEADBEEF
// Request packet type
packetTypeRequest uint8 = 1
// Response packet type
packetTypeResponse uint8 = 2
// Packet format:
// - 4 bytes: magic header (0xDEADBEEF)
// - 1 byte: packet type (1 = request, 2 = response)
// - 8 bytes: timestamp (for round-trip timing)
packetSize = 13
)
// Server handles listening for connection check requests
type Server struct {
conn *net.UDPConn
listenAddr string
shutdownCh chan struct{}
isRunning bool
runningLock sync.Mutex
}
// NewServer creates a new connection test server
func NewServer(listenAddr string) *Server {
return &Server{
listenAddr: listenAddr,
shutdownCh: make(chan struct{}),
}
}
// Start begins listening for connection test packets
func (s *Server) Start() error {
s.runningLock.Lock()
defer s.runningLock.Unlock()
if s.isRunning {
return nil
}
addr, err := net.ResolveUDPAddr("udp", s.listenAddr)
if err != nil {
return err
}
s.conn, err = net.ListenUDP("udp", addr)
if err != nil {
return err
}
s.isRunning = true
go s.handleConnections()
log.Printf("Server listening on %s", s.listenAddr)
return nil
}
// Stop shuts down the server
func (s *Server) Stop() {
s.runningLock.Lock()
defer s.runningLock.Unlock()
if !s.isRunning {
return
}
close(s.shutdownCh)
if s.conn != nil {
s.conn.Close()
}
s.isRunning = false
log.Println("Server stopped")
}
// handleConnections processes incoming packets
func (s *Server) handleConnections() {
buffer := make([]byte, packetSize)
for {
select {
case <-s.shutdownCh:
return
default:
// Set read deadline to avoid blocking forever
s.conn.SetReadDeadline(time.Now().Add(1 * time.Second))
n, addr, err := s.conn.ReadFromUDP(buffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Just a timeout, keep going
continue
}
log.Printf("Error reading from UDP: %v", err)
continue
}
if n != packetSize {
continue // Ignore malformed packets
}
// Check magic header
magic := binary.BigEndian.Uint32(buffer[0:4])
if magic != magicHeader {
continue // Not our packet
}
// Check packet type
packetType := buffer[4]
if packetType != packetTypeRequest {
continue // Not a request packet
}
// Keep the timestamp the same (for RTT calculation)
// Just change the packet type to response
buffer[4] = packetTypeResponse
// Send response
_, err = s.conn.WriteToUDP(buffer, addr)
if err != nil {
log.Printf("Error sending response: %v", err)
}
}
}
}
// Client handles checking connectivity to a server
type Client struct {
conn *net.UDPConn
serverAddr string
monitorRunning bool
monitorLock sync.Mutex
shutdownCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
}
// ConnectionStatus represents the current connection state
type ConnectionStatus struct {
Connected bool
RTT time.Duration
}
// NewClient creates a new connection test client
func NewClient(serverAddr string) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
}, nil
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
}
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
}
// Close cleans up client resources
func (c *Client) Close() {
c.StopMonitor()
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
}
// ensureConnection makes sure we have an active UDP connection
func (c *Client) ensureConnection() error {
if c.conn != nil {
return nil
}
serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
return err
}
c.conn, err = net.DialUDP("udp", nil, serverAddr)
if err != nil {
return err
}
return nil
}
// TestConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
if err := c.ensureConnection(); err != nil {
return false, 0
}
// Prepare packet buffer
packet := make([]byte, packetSize)
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
packet[4] = packetTypeRequest
// Send multiple attempts as specified
for attempt := 0; attempt < c.maxAttempts; attempt++ {
select {
case <-ctx.Done():
return false, 0
default:
// Add current timestamp to packet
timestamp := time.Now().UnixNano()
binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp))
// Send the packet
_, err := c.conn.Write(packet)
if err != nil {
log.Printf("Error sending packet: %v", err)
continue
}
// Set read deadline
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
// Wait for response
responseBuffer := make([]byte, packetSize)
n, err := c.conn.Read(responseBuffer)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout, try next attempt
time.Sleep(100 * time.Millisecond) // Brief pause between attempts
continue
}
log.Printf("Error reading response: %v", err)
continue
}
if n != packetSize {
continue // Malformed packet
}
// Verify response
magic := binary.BigEndian.Uint32(responseBuffer[0:4])
packetType := responseBuffer[4]
if magic != magicHeader || packetType != packetTypeResponse {
continue // Not our response
}
// Extract the original timestamp and calculate RTT
sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13]))
rtt := time.Duration(time.Now().UnixNano() - sentTimestamp)
return true, rtt
}
}
return false, 0
}
// TestConnectionWithTimeout tries to test connection with a timeout
// Returns true if connected, false otherwise
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
type MonitorCallback func(status ConnectionStatus)
// StartMonitor begins monitoring the connection and calls the callback
// when the connection status changes
func (c *Client) StartMonitor(callback MonitorCallback) error {
c.monitorLock.Lock()
defer c.monitorLock.Unlock()
if c.monitorRunning {
return nil // Already running
}
if err := c.ensureConnection(); err != nil {
return err
}
c.monitorRunning = true
c.shutdownCh = make(chan struct{})
go func() {
var lastConnected bool
firstRun := true
ticker := time.NewTicker(c.packetInterval)
defer ticker.Stop()
for {
select {
case <-c.shutdownCh:
return
case <-ticker.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx)
cancel()
// Callback if status changed or it's the first check
if connected != lastConnected || firstRun {
callback(ConnectionStatus{
Connected: connected,
RTT: rtt,
})
lastConnected = connected
firstRun = false
}
}
}
}()
return nil
}
// StopMonitor stops the connection monitoring
func (c *Client) StopMonitor() {
c.monitorLock.Lock()
defer c.monitorLock.Unlock()
if !c.monitorRunning {
return
}
close(c.shutdownCh)
c.monitorRunning = false
}