mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
Add peer monitor
This commit is contained in:
116
common.go
116
common.go
@@ -6,11 +6,11 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
@@ -56,10 +56,12 @@ type EncryptedHolePunchMessage struct {
|
||||
}
|
||||
|
||||
// Package-level state shared across the olm client.
var (
	peerMonitor        *peermonitor.PeerMonitor // connectivity monitor for site peers (reports via handlePeerStatusChange)
	stopHolepunch      chan struct{}            // closed to stop the hole-punching loop
	stopRegister       chan struct{}            // signal channel for stopping registration retries — presumably closed like stopHolepunch; verify against caller
	olmToken           string
	gerbilServerPubKey string
	peerStatusMap      map[int]bool // last reported connected state per site ID (written by handlePeerStatusChange)
)
|
||||
|
||||
const (
|
||||
@@ -358,87 +360,6 @@ func keepSendingRegistration(olm *websocket.Client, publicKey string) {
|
||||
}
|
||||
}
|
||||
|
||||
func monitorConnection(dev *device.Device, onTimeout func()) {
|
||||
const (
|
||||
checkInterval = 100 * time.Millisecond // Check every 0.1 seconds
|
||||
timeout = 500 * time.Millisecond // Total timeout of 1.5 seconds
|
||||
)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeoutTimer := time.NewTimer(timeout)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
// var lastSent uint64
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Get the current device statistics
|
||||
deviceInfo, err := dev.IpcGet()
|
||||
if err != nil {
|
||||
logger.Error("Failed to get device statistics: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the statistics from the IPC output
|
||||
stats := parseStatistics(deviceInfo)
|
||||
|
||||
logger.Info("Received: %d, Sent: %d", stats.received, stats.sent)
|
||||
|
||||
// Check if we've received any new bytes
|
||||
if stats.received > 0 {
|
||||
// Connection is successful, we received data
|
||||
logger.Info("Connection established - received bytes detected")
|
||||
return
|
||||
}
|
||||
|
||||
// Update the last known values
|
||||
// lastSent = stats.sent
|
||||
|
||||
case <-timeoutTimer.C:
|
||||
// We've hit our timeout without seeing any received bytes
|
||||
logger.Warn("Connection timeout - no data received within %v", timeout)
|
||||
onTimeout()
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// statistics holds byte counters parsed from the WireGuard device's IPC output.
type statistics struct {
	received uint64 // total bytes received (rx_bytes)
	sent     uint64 // total bytes sent (tx_bytes)
}

// parseStatistics extracts the rx_bytes and tx_bytes counters from the
// device's IPC info string. Lines that are absent or fail to parse are
// ignored, leaving the corresponding counter at zero (best-effort parsing).
// If a counter appears on multiple lines, the last occurrence wins.
func parseStatistics(info string) statistics {
	var stats statistics

	for _, line := range strings.Split(info, "\n") {
		switch {
		case strings.HasPrefix(line, "rx_bytes="):
			if v, err := strconv.ParseUint(strings.TrimPrefix(line, "rx_bytes="), 10, 64); err == nil {
				stats.received = v
			}
		case strings.HasPrefix(line, "tx_bytes="):
			if v, err := strconv.ParseUint(strings.TrimPrefix(line, "tx_bytes="), 10, 64); err == nil {
				stats.sent = v
			}
		}
	}

	return stats
}
|
||||
|
||||
func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
if maxPort < minPort {
|
||||
return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort)
|
||||
@@ -474,3 +395,34 @@ func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) {
|
||||
|
||||
return 0, fmt.Errorf("no available UDP ports found in range %d-%d", minPort, maxPort)
|
||||
}
|
||||
|
||||
func handlePeerStatusChange(siteID int, connected bool, rtt time.Duration) {
|
||||
// Check if status has changed
|
||||
prevStatus, exists := peerStatusMap[siteID]
|
||||
if !exists || prevStatus != connected {
|
||||
if connected {
|
||||
logger.Info("Peer %d is now connected (RTT: %v)", siteID, rtt)
|
||||
// Add any actions you want to take when a peer connects
|
||||
|
||||
// Example: try to send a relay message if this is the first peer to connect
|
||||
if !prevStatus && !exists {
|
||||
// This is a new connection, not just a status update
|
||||
go func() {
|
||||
// Give wireguard a moment to establish properly
|
||||
// time.Sleep(500 * time.Millisecond)
|
||||
// if olm != nil {
|
||||
// if err := sendRelay(olm); err != nil {
|
||||
// logger.Error("Failed to send relay message: %v", err)
|
||||
// }
|
||||
// }
|
||||
}()
|
||||
}
|
||||
} else {
|
||||
logger.Warn("Peer %d is disconnected", siteID)
|
||||
// Add any actions you want to take when a peer disconnects
|
||||
}
|
||||
|
||||
// Update status map
|
||||
peerStatusMap[siteID] = connected
|
||||
}
|
||||
}
|
||||
|
||||
44
main.go
44
main.go
@@ -15,8 +15,10 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/olm/peermonitor"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
"github.com/vishvananda/netlink"
|
||||
|
||||
@@ -133,6 +135,16 @@ func main() {
|
||||
|
||||
stopHolepunch = make(chan struct{})
|
||||
stopRegister = make(chan struct{})
|
||||
peerStatusMap = make(map[int]bool)
|
||||
|
||||
// Initialize the peer monitor
|
||||
peerMonitor = peermonitor.NewPeerMonitor(handlePeerStatusChange)
|
||||
defer peerMonitor.Close()
|
||||
|
||||
// Set custom monitoring parameters if needed
|
||||
peerMonitor.SetInterval(5 * time.Second)
|
||||
peerMonitor.SetTimeout(500 * time.Millisecond)
|
||||
peerMonitor.SetMaxAttempts(3)
|
||||
|
||||
// if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values
|
||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||
@@ -382,6 +394,13 @@ func main() {
|
||||
configBuilder.WriteString(fmt.Sprintf("allowed_ip=%s\n", allowedIpStr))
|
||||
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
|
||||
configBuilder.WriteString("persistent_keepalive_interval=1\n")
|
||||
|
||||
err = peerMonitor.AddPeer(site.SiteId, siteHost)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to setup monitoring for site %d: %v", site.SiteId, err)
|
||||
} else {
|
||||
logger.Info("Started monitoring for site %d at %s", site.SiteId, siteHost)
|
||||
}
|
||||
}
|
||||
|
||||
config := configBuilder.String()
|
||||
@@ -406,30 +425,7 @@ func main() {
|
||||
|
||||
close(stopHolepunch)
|
||||
|
||||
// Monitor the connection for activity
|
||||
monitorConnection(dev, func() { // TODO: this now has to be per site
|
||||
// host, err := resolveDomain(endpoint)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to resolve endpoint: %v", err)
|
||||
// return
|
||||
// }
|
||||
|
||||
// // Configure WireGuard
|
||||
// config := fmt.Sprintf(`private_key=%s
|
||||
// public_key=%s
|
||||
// allowed_ip=%s/32
|
||||
// endpoint=%s:21820
|
||||
// persistent_keepalive_interval=1`, fixKey(privateKey.String()), fixKey(wgData.PublicKey), wgData.ServerIP, host)
|
||||
|
||||
// err = dev.IpcSet(config)
|
||||
// if err != nil {
|
||||
// logger.Error("Failed to configure WireGuard device: %v", err)
|
||||
// }
|
||||
|
||||
// logger.Info("Adjusted to point to relay!")
|
||||
|
||||
// sendRelay(olm)
|
||||
})
|
||||
peerMonitor.Start()
|
||||
|
||||
logger.Info("WireGuard device created.")
|
||||
})
|
||||
|
||||
232
peermonitor/peermonitor.go
Normal file
232
peermonitor/peermonitor.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package peermonitor
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/olm/wgtester"
|
||||
)
|
||||
|
||||
// PeerMonitorCallback is invoked on connection status reports for a
// monitored peer: connected is the current reachability and rtt the
// measured round-trip time.
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)

// PeerMonitor handles monitoring the connection status to multiple
// WireGuard peers, one wgtester.Client per site ID.
type PeerMonitor struct {
	monitors    map[int]*wgtester.Client // active tester clients keyed by site ID
	callback    PeerMonitorCallback      // invoked with each client's status reports
	mutex       sync.Mutex               // guards every field of this struct
	running     bool                     // true between Start and Stop/Close
	interval    time.Duration            // how often each peer is probed
	timeout     time.Duration            // how long to wait for a probe response
	maxAttempts int                      // attempt budget used by TestConnection
}
|
||||
|
||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||
func NewPeerMonitor(callback PeerMonitorCallback) *PeerMonitor {
|
||||
return &PeerMonitor{
|
||||
monitors: make(map[int]*wgtester.Client),
|
||||
callback: callback,
|
||||
interval: 5 * time.Second, // Default check interval
|
||||
timeout: 500 * time.Millisecond,
|
||||
maxAttempts: 3,
|
||||
}
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
func (pm *PeerMonitor) AddPeer(siteID int, endpoint string) error {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Check if we're already monitoring this peer
|
||||
if _, exists := pm.monitors[siteID]; exists {
|
||||
// Update the endpoint instead of creating a new monitor
|
||||
pm.RemovePeer(siteID)
|
||||
}
|
||||
|
||||
// Add UDP port if not present, assuming default WireGuard port
|
||||
if _, _, err := net.SplitHostPort(endpoint); err != nil {
|
||||
endpoint = endpoint + ":51820" // Default WireGuard port
|
||||
}
|
||||
|
||||
client, err := wgtester.NewClient(endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Configure the client with our settings
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
|
||||
// Store the client
|
||||
pm.monitors[siteID] = client
|
||||
|
||||
// If monitor is already running, start monitoring this peer
|
||||
if pm.running {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
err = client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.callback(siteIDCopy, status.Connected, status.RTT)
|
||||
})
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// RemovePeer stops monitoring a peer and removes it from the monitor
|
||||
func (pm *PeerMonitor) RemovePeer(siteID int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
client, exists := pm.monitors[siteID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
// Start begins monitoring all peers
|
||||
func (pm *PeerMonitor) Start() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if pm.running {
|
||||
return // Already running
|
||||
}
|
||||
|
||||
pm.running = true
|
||||
|
||||
// Start monitoring all peers
|
||||
for siteID, client := range pm.monitors {
|
||||
siteIDCopy := siteID // Create a copy for the closure
|
||||
client.StartMonitor(func(status wgtester.ConnectionStatus) {
|
||||
pm.callback(siteIDCopy, status.Connected, status.RTT)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Stop stops monitoring all peers
|
||||
func (pm *PeerMonitor) Stop() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
if !pm.running {
|
||||
return
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
|
||||
// Stop all monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
}
|
||||
}
|
||||
|
||||
// Close stops monitoring and cleans up resources
|
||||
func (pm *PeerMonitor) Close() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
// Stop and close all clients
|
||||
for siteID, client := range pm.monitors {
|
||||
client.StopMonitor()
|
||||
client.Close()
|
||||
delete(pm.monitors, siteID)
|
||||
}
|
||||
|
||||
pm.running = false
|
||||
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*wgtester.Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
Reference in New Issue
Block a user