From 5d305f1d03a00f521f6ed5cb906e3b1102597045 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 8 Apr 2025 21:57:53 -0400 Subject: [PATCH] Add peer monitor --- common.go | 116 ++++++------------- main.go | 44 ++++--- peermonitor/peermonitor.go | 232 +++++++++++++++++++++++++++++++++++++ 3 files changed, 286 insertions(+), 106 deletions(-) create mode 100644 peermonitor/peermonitor.go diff --git a/common.go b/common.go index fd9c0d6..6e777d3 100644 --- a/common.go +++ b/common.go @@ -6,11 +6,11 @@ import ( "encoding/json" "fmt" "net" - "strconv" "strings" "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" @@ -56,10 +56,12 @@ type EncryptedHolePunchMessage struct { } var ( + peerMonitor *peermonitor.PeerMonitor stopHolepunch chan struct{} stopRegister chan struct{} olmToken string gerbilServerPubKey string + peerStatusMap map[int]bool ) const ( @@ -358,87 +360,6 @@ func keepSendingRegistration(olm *websocket.Client, publicKey string) { } } -func monitorConnection(dev *device.Device, onTimeout func()) { - const ( - checkInterval = 100 * time.Millisecond // Check every 0.1 seconds - timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds - ) - - go func() { - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() - - timeoutTimer := time.NewTimer(timeout) - defer timeoutTimer.Stop() - - // var lastSent uint64 - - for { - select { - case <-ticker.C: - // Get the current device statistics - deviceInfo, err := dev.IpcGet() - if err != nil { - logger.Error("Failed to get device statistics: %v", err) - continue - } - - // Parse the statistics from the IPC output - stats := parseStatistics(deviceInfo) - - logger.Info("Received: %d, Sent: %d", stats.received, stats.sent) - - // Check if we've received any new bytes - if stats.received > 0 { - // Connection is successful, we received data - logger.Info("Connection established - received bytes detected") - return - } - - // Update the last known values - // lastSent = stats.sent - - case <-timeoutTimer.C: - // We've hit our timeout without seeing any received bytes - logger.Warn("Connection timeout - no data received within %v", timeout) - onTimeout() - return - } - } - }() -} - -// statistics holds the parsed byte counts from the device -type statistics struct { - received uint64 - sent uint64 -} - -// parseStatistics extracts the received and sent byte counts from the device info string -func parseStatistics(info string) statistics { - var stats statistics - - // Split the device info into lines - lines := strings.Split(info, "\n") - - // Look for the transfer_receive and transfer_send lines - for _, line := range lines { - if strings.HasPrefix(line, "rx_bytes=") { - valueStr := strings.TrimPrefix(line, "rx_bytes=") - if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil { - stats.received = value - } - } else if strings.HasPrefix(line, "tx_bytes=") { - valueStr := strings.TrimPrefix(line, "tx_bytes=") - if value, err := strconv.ParseUint(valueStr, 10, 64); err == nil { - stats.sent = value - } - } - } - - return stats -} - func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) @@ -474,3 +395,34 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort) } + +func handlePeerStatusChange(siteID int, connected bool, rtt time.Duration) { + // Check if status has changed + prevStatus, exists := peerStatusMap[siteID] + if !exists || prevStatus != connected { + if connected { + logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt) + // Add any actions you want to take when a peer connects + + // Example: try to send a relay message if this is the first peer to connect + if !prevStatus && !exists { + // This is a new connection, not just a status update + go func() { + // Give wireguard a moment to establish properly + // time.Sleep(500 * time.Millisecond) + // if olm != nil { + // if err := sendRelay(olm); err != nil { + // logger.Error("Failed to send relay message: %v", err) + // } + // } + }() + } + } else { + logger.Warn("Peer %d is disconnected", siteID) + // Add any actions you want to take when a peer disconnects + } + + // Update status map + peerStatusMap[siteID] = connected + } +} diff --git a/main.go b/main.go index 9ef29e9..489b831 100644 --- a/main.go +++ b/main.go @@ -15,8 +15,10 @@ import ( "strconv" "strings" "syscall" + "time" "github.com/fosrl/newt/logger" + "github.com/fosrl/olm/peermonitor" "github.com/fosrl/olm/websocket" "github.com/vishvananda/netlink" @@ -133,6 +135,16 @@ func main() { stopHolepunch = make(chan struct{}) stopRegister = make(chan struct{}) + peerStatusMap = make(map[int]bool) + + // Initialize the peer monitor + peerMonitor = peermonitor.NewPeerMonitor(handlePeerStatusChange) + defer peerMonitor.Close() + + // Set custom monitoring parameters if needed + peerMonitor.SetInterval(5 * time.Second) + peerMonitor.SetTimeout(500 * time.Millisecond) + peerMonitor.SetMaxAttempts(3) // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values endpoint = os.Getenv("PANGOLIN_ENDPOINT") @@ -382,6 +394,13 @@ func main() { configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr)) configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) configBuilder.WriteString("persistent_keepalive_interval=1\n") + + err = peerMonitor.AddPeer(site.SiteId, siteHost) + if err != nil { + logger.Warn("Failed to setup monitoring for site %d: %v", site.SiteId, err) + } else { + logger.Info("Started monitoring for site %d at %s", site.SiteId, siteHost) + } } config := configBuilder.String() @@ -406,30 +425,7 @@ func main() { close(stopHolepunch) - // Monitor the connection for activity - monitorConnection(dev, func() { // TODO: this now has to be per site - // host, err := resolveDomain(endpoint) - // if err != nil { - // logger.Error("Failed to resolve endpoint: %v", err) - // return - // } - - // // Configure WireGuard - // config := fmt.Sprintf(`private_key=%s - // public_key=%s - // allowed_ip=%s/32 - // endpoint=%s:21820 - // persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host) - - // err = dev.IpcSet(config) - // if err != nil { - // logger.Error("Failed to configure WireGuard device: %v", err) - // } - - // logger.Info("Adjusted to point to relay!") - - // sendRelay(olm) - }) + peerMonitor.Start() logger.Info("WireGuard device created.") }) diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go new file mode 100644 index 0000000..665b303 --- /dev/null +++ b/peermonitor/peermonitor.go @@ -0,0 +1,232 @@ +package peermonitor + +import ( + "context" + "fmt" + "net" + "sync" + "time" + + "github.com/fosrl/olm/wgtester" +) + +// PeerMonitorCallback is the function type for connection status change callbacks +type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration) + +// PeerMonitor handles monitoring the connection status to multiple WireGuard peers +type PeerMonitor struct { + monitors map[int]*wgtester.Client + callback PeerMonitorCallback + mutex sync.Mutex + running bool + interval time.Duration + timeout time.Duration + maxAttempts int +} + +// NewPeerMonitor creates a new peer monitor with the given callback +func NewPeerMonitor(callback PeerMonitorCallback) *PeerMonitor { + return &PeerMonitor{ + monitors: make(map[int]*wgtester.Client), + callback: callback, + interval: 5 * time.Second, // Default check interval + timeout: 500 * time.Millisecond, + maxAttempts: 3, + } +} + +// SetInterval changes how frequently peers are checked +func (pm *PeerMonitor) SetInterval(interval time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = interval + + // Update interval for all existing monitors + for _, client := range pm.monitors { + client.SetPacketInterval(interval) + } +} + +// SetTimeout changes the timeout for waiting for responses +func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.timeout = timeout + + // Update timeout for all existing monitors + for _, client := range pm.monitors { + client.SetTimeout(timeout) + } +} + +// SetMaxAttempts changes the maximum number of attempts for TestConnection +func (pm *PeerMonitor) SetMaxAttempts(attempts int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.maxAttempts = attempts + + // Update max attempts for all existing monitors + for _, client := range pm.monitors { + client.SetMaxAttempts(attempts) + } +} + +// AddPeer adds a new peer to monitor +func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + // Check if we're already monitoring this peer + if _, exists := pm.monitors[siteID]; exists { + // Update the endpoint instead of creating a new monitor + pm.RemovePeer(siteID) + } + + // Add UDP port if not present, assuming default WireGuard port + if _, _, err := net.SplitHostPort(endpoint); err != nil { + endpoint = endpoint + ":51820" // Default WireGuard port + } + + client, err := wgtester.NewClient(endpoint) + if err != nil { + return err + } + + // Configure the client with our settings + client.SetPacketInterval(pm.interval) + client.SetTimeout(pm.timeout) + client.SetMaxAttempts(pm.maxAttempts) + + // Store the client + pm.monitors[siteID] = client + + // If monitor is already running, start monitoring this peer + if pm.running { + siteIDCopy := siteID // Create a copy for the closure + err = client.StartMonitor(func(status wgtester.ConnectionStatus) { + pm.callback(siteIDCopy, status.Connected, status.RTT) + }) + } + + return err +} + +// RemovePeer stops monitoring a peer and removes it from the monitor +func (pm *PeerMonitor) RemovePeer(siteID int) { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + client, exists := pm.monitors[siteID] + if !exists { + return + } + + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) +} + +// Start begins monitoring all peers +func (pm *PeerMonitor) Start() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if pm.running { + return // Already running + } + + pm.running = true + + // Start monitoring all peers + for siteID, client := range pm.monitors { + siteIDCopy := siteID // Create a copy for the closure + client.StartMonitor(func(status wgtester.ConnectionStatus) { + pm.callback(siteIDCopy, status.Connected, status.RTT) + }) + } +} + +// Stop stops monitoring all peers +func (pm *PeerMonitor) Stop() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + if !pm.running { + return + } + + pm.running = false + + // Stop all monitors + for _, client := range pm.monitors { + client.StopMonitor() + } +} + +// Close stops monitoring and cleans up resources +func (pm *PeerMonitor) Close() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + // Stop and close all clients + for siteID, client := range pm.monitors { + client.StopMonitor() + client.Close() + delete(pm.monitors, siteID) + } + + pm.running = false +} + +// TestPeer tests connectivity to a specific peer +func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { + pm.mutex.Lock() + client, exists := pm.monitors[siteID] + pm.mutex.Unlock() + + if !exists { + return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) + } + + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + defer cancel() + + connected, rtt := client.TestConnection(ctx) + return connected, rtt, nil +} + +// TestAllPeers tests connectivity to all peers +func (pm *PeerMonitor) TestAllPeers() map[int]struct { + Connected bool + RTT time.Duration +} { + pm.mutex.Lock() + peers := make(map[int]*wgtester.Client, len(pm.monitors)) + for siteID, client := range pm.monitors { + peers[siteID] = client + } + pm.mutex.Unlock() + + results := make(map[int]struct { + Connected bool + RTT time.Duration + }) + for siteID, client := range peers { + ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) + connected, rtt := client.TestConnection(ctx) + cancel() + + results[siteID] = struct { + Connected bool + RTT time.Duration + }{ + Connected: connected, + RTT: rtt, + } + } + + return results +}