From e35f7c2d36dd93553bbe31d0ef85a7934d061035 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 31 Mar 2025 18:11:04 -0400 Subject: [PATCH] Add wgtester --- common.go | 14 +- main.go | 101 ++++++++----- wgtester/wgtester.go | 347 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 419 insertions(+), 43 deletions(-) create mode 100644 wgtester/wgtester.go diff --git a/common.go b/common.go index 7bb14d9..720c90e 100644 --- a/common.go +++ b/common.go @@ -21,11 +21,15 @@ import ( ) type WgData struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` - ServerIP string `json:"serverIP"` - TunnelIP string `json:"tunnelIP"` - Targets TargetsByType `json:"targets"` + Sites []SiteConfig `json:"sites"` + TunnelIP string `json:"tunnelIP"` +} + +type SiteConfig struct { + SiteId string `json:"siteId"` + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` + ServerIP string `json:"serverIP"` } type TargetsByType struct { diff --git a/main.go b/main.go index 251991f..559c354 100644 --- a/main.go +++ b/main.go @@ -13,6 +13,7 @@ import ( "regexp" "runtime" "strconv" + "strings" "syscall" "github.com/fosrl/newt/logger" @@ -235,18 +236,30 @@ func main() { return } - endpoint, err := resolveDomain(wgData.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return + // Configure WireGuard with all sites as peers + var configBuilder strings.Builder + + // Start with the private key + configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) + + // Add each site as a peer + for _, site := range wgData.Sites { + siteHost, err := resolveDomain(site.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint for site %s: %v", site.SiteId, err) + continue + } + + // Include peer info + configBuilder.WriteString(fmt.Sprintf("\n# Site %s\n", site.SiteId)) + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(site.PublicKey))) + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s/32\n", site.ServerIP)) + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=1\n") } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s - public_key=%s - allowed_ip=%s/32 - endpoint=%s - persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, endpoint) + config := configBuilder.String() + logger.Debug("WireGuard config: %s", config) err = dev.IpcSet(config) if err != nil { @@ -377,18 +390,30 @@ func main() { logger.Info("UAPI listener started") - host, err := resolveDomain(wgData.Endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return + // Configure WireGuard with all sites as peers + var configBuilder strings.Builder + + // Start with the private key + configBuilder.WriteString(fmt.Sprintf("private_key=%s\n", fixKey(privateKey.String()))) + + // Add each site as a peer + for _, site := range wgData.Sites { + siteHost, err := resolveDomain(site.Endpoint) + if err != nil { + logger.Warn("Failed to resolve endpoint for site %s: %v", site.SiteId, err) + continue + } + + // Include peer info + configBuilder.WriteString(fmt.Sprintf("\n# Site %s\n", site.SiteId)) + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", fixKey(site.PublicKey))) + configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s/32\n", site.ServerIP)) + configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) + configBuilder.WriteString("persistent_keepalive_interval=1\n") } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s -persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) + config := configBuilder.String() + logger.Debug("WireGuard config: %s", config) err = dev.IpcSet(config) if err != nil { @@ -410,28 +435,28 @@ persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.Pub close(stopHolepunch) // Monitor the connection for activity - monitorConnection(dev, func() { - host, err := resolveDomain(endpoint) - if err != nil { - logger.Error("Failed to resolve endpoint: %v", err) - return - } + monitorConnection(dev, func() { // TODO: this now has to be per site + // host, err := resolveDomain(endpoint) + // if err != nil { + // logger.Error("Failed to resolve endpoint: %v", err) + // return + // } - // Configure WireGuard - config := fmt.Sprintf(`private_key=%s -public_key=%s -allowed_ip=%s/32 -endpoint=%s:21820 -persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) + // // Configure WireGuard + // config := fmt.Sprintf(`private_key=%s + // public_key=%s + // allowed_ip=%s/32 + // endpoint=%s:21820 + // persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) - err = dev.IpcSet(config) - if err != nil { - logger.Error("Failed to configure WireGuard device: %v", err) - } + // err = dev.IpcSet(config) + // if err != nil { + // logger.Error("Failed to configure WireGuard device: %v", err) + // } - logger.Info("Adjusted to point to relay!") + // logger.Info("Adjusted to point to relay!") - sendRelay(olm) + // sendRelay(olm) }) logger.Info("WireGuard device created.") diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go new file mode 100644 index 0000000..cefe785 --- /dev/null +++ b/wgtester/wgtester.go @@ -0,0 +1,347 @@ +package wgtester + +import ( + "context" + "encoding/binary" + "log" + "net" + "sync" + "time" +) + +const ( + // Magic bytes to identify our packets + magicHeader uint32 = 0xDEADBEEF + // Request packet type + packetTypeRequest uint8 = 1 + // Response packet type + packetTypeResponse uint8 = 2 + // Packet format: + // - 4 bytes: magic header (0xDEADBEEF) + // - 1 byte: packet type (1 = request, 2 = response) + // - 8 bytes: timestamp (for round-trip timing) + packetSize = 13 +) + +// Server handles listening for connection check requests +type Server struct { + conn *net.UDPConn + listenAddr string + shutdownCh chan struct{} + isRunning bool + runningLock sync.Mutex +} + +// NewServer creates a new connection test server +func NewServer(listenAddr string) *Server { + return &Server{ + listenAddr: listenAddr, + shutdownCh: make(chan struct{}), + } +} + +// Start begins listening for connection test packets +func (s *Server) Start() error { + s.runningLock.Lock() + defer s.runningLock.Unlock() + + if s.isRunning { + return nil + } + + addr, err := net.ResolveUDPAddr("udp", s.listenAddr) + if err != nil { + return err + } + + s.conn, err = net.ListenUDP("udp", addr) + if err != nil { + return err + } + + s.isRunning = true + go s.handleConnections() + + log.Printf("Server listening on %s", s.listenAddr) + return nil +} + +// Stop shuts down the server +func (s *Server) Stop() { + s.runningLock.Lock() + defer s.runningLock.Unlock() + + if !s.isRunning { + return + } + + close(s.shutdownCh) + if s.conn != nil { + s.conn.Close() + } + s.isRunning = false + log.Println("Server stopped") +} + +// handleConnections processes incoming packets +func (s *Server) handleConnections() { + buffer := make([]byte, packetSize) + + for { + select { + case <-s.shutdownCh: + return + default: + // Set read deadline to avoid blocking forever + s.conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + + n, addr, err := s.conn.ReadFromUDP(buffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Just a timeout, keep going + continue + } + log.Printf("Error reading from UDP: %v", err) + continue + } + + if n != packetSize { + continue // Ignore malformed packets + } + + // Check magic header + magic := binary.BigEndian.Uint32(buffer[0:4]) + if magic != magicHeader { + continue // Not our packet + } + + // Check packet type + packetType := buffer[4] + if packetType != packetTypeRequest { + continue // Not a request packet + } + + // Keep the timestamp the same (for RTT calculation) + // Just change the packet type to response + buffer[4] = packetTypeResponse + + // Send response + _, err = s.conn.WriteToUDP(buffer, addr) + if err != nil { + log.Printf("Error sending response: %v", err) + } + } + } +} + +// Client handles checking connectivity to a server +type Client struct { + conn *net.UDPConn + serverAddr string + monitorRunning bool + monitorLock sync.Mutex + shutdownCh chan struct{} + packetInterval time.Duration + timeout time.Duration + maxAttempts int +} + +// ConnectionStatus represents the current connection state +type ConnectionStatus struct { + Connected bool + RTT time.Duration +} + +// NewClient creates a new connection test client +func NewClient(serverAddr string) (*Client, error) { + return &Client{ + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + packetInterval: 2 * time.Second, + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + }, nil +} + +// SetPacketInterval changes how frequently packets are sent in monitor mode +func (c *Client) SetPacketInterval(interval time.Duration) { + c.packetInterval = interval +} + +// SetTimeout changes the timeout for waiting for responses +func (c *Client) SetTimeout(timeout time.Duration) { + c.timeout = timeout +} + +// SetMaxAttempts changes the maximum number of attempts for TestConnection +func (c *Client) SetMaxAttempts(attempts int) { + c.maxAttempts = attempts +} + +// Close cleans up client resources +func (c *Client) Close() { + c.StopMonitor() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +// ensureConnection makes sure we have an active UDP connection +func (c *Client) ensureConnection() error { + if c.conn != nil { + return nil + } + + serverAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) + if err != nil { + return err + } + + c.conn, err = net.DialUDP("udp", nil, serverAddr) + if err != nil { + return err + } + + return nil +} + +// TestConnection checks if the connection to the server is working +// Returns true if connected, false otherwise +func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { + if err := c.ensureConnection(); err != nil { + return false, 0 + } + + // Prepare packet buffer + packet := make([]byte, packetSize) + binary.BigEndian.PutUint32(packet[0:4], magicHeader) + packet[4] = packetTypeRequest + + // Send multiple attempts as specified + for attempt := 0; attempt < c.maxAttempts; attempt++ { + select { + case <-ctx.Done(): + return false, 0 + default: + // Add current timestamp to packet + timestamp := time.Now().UnixNano() + binary.BigEndian.PutUint64(packet[5:13], uint64(timestamp)) + + // Send the packet + _, err := c.conn.Write(packet) + if err != nil { + log.Printf("Error sending packet: %v", err) + continue + } + + // Set read deadline + c.conn.SetReadDeadline(time.Now().Add(c.timeout)) + + // Wait for response + responseBuffer := make([]byte, packetSize) + n, err := c.conn.Read(responseBuffer) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Timeout, try next attempt + time.Sleep(100 * time.Millisecond) // Brief pause between attempts + continue + } + log.Printf("Error reading response: %v", err) + continue + } + + if n != packetSize { + continue // Malformed packet + } + + // Verify response + magic := binary.BigEndian.Uint32(responseBuffer[0:4]) + packetType := responseBuffer[4] + if magic != magicHeader || packetType != packetTypeResponse { + continue // Not our response + } + + // Extract the original timestamp and calculate RTT + sentTimestamp := int64(binary.BigEndian.Uint64(responseBuffer[5:13])) + rtt := time.Duration(time.Now().UnixNano() - sentTimestamp) + + return true, rtt + } + } + + return false, 0 +} + +// TestConnectionWithTimeout tries to test connection with a timeout +// Returns true if connected, false otherwise +func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return c.TestConnection(ctx) +} + +// MonitorCallback is the function type for connection status change callbacks +type MonitorCallback func(status ConnectionStatus) + +// StartMonitor begins monitoring the connection and calls the callback +// when the connection status changes +func (c *Client) StartMonitor(callback MonitorCallback) error { + c.monitorLock.Lock() + defer c.monitorLock.Unlock() + + if c.monitorRunning { + return nil // Already running + } + + if err := c.ensureConnection(); err != nil { + return err + } + + c.monitorRunning = true + c.shutdownCh = make(chan struct{}) + + go func() { + var lastConnected bool + firstRun := true + + ticker := time.NewTicker(c.packetInterval) + defer ticker.Stop() + + for { + select { + case <-c.shutdownCh: + return + case <-ticker.C: + ctx, cancel := context.WithTimeout(context.Background(), c.timeout) + connected, rtt := c.TestConnection(ctx) + cancel() + + // Callback if status changed or it's the first check + if connected != lastConnected || firstRun { + callback(ConnectionStatus{ + Connected: connected, + RTT: rtt, + }) + lastConnected = connected + firstRun = false + } + } + } + }() + + return nil +} + +// StopMonitor stops the connection monitoring +func (c *Client) StopMonitor() { + c.monitorLock.Lock() + defer c.monitorLock.Unlock() + + if !c.monitorRunning { + return + } + + close(c.shutdownCh) + c.monitorRunning = false +}