mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
12
olm/olm.go
12
olm/olm.go
@@ -419,6 +419,7 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
config.Holepunch && !config.DisableRelay, // Enable relay only if holepunching is enabled and DisableRelay is false
|
config.Holepunch && !config.DisableRelay, // Enable relay only if holepunching is enabled and DisableRelay is false
|
||||||
middleDev,
|
middleDev,
|
||||||
interfaceIP,
|
interfaceIP,
|
||||||
|
sharedBind, // Pass sharedBind for holepunch testing
|
||||||
)
|
)
|
||||||
|
|
||||||
peerManager = peers.NewPeerManager(dev, peerMonitor, dnsProxy, interfaceName, privateKey)
|
peerManager = peers.NewPeerManager(dev, peerMonitor, dnsProxy, interfaceName, privateKey)
|
||||||
@@ -432,9 +433,20 @@ func StartTunnel(config TunnelConfig) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add holepunch monitoring for this endpoint if holepunching is enabled
|
||||||
|
if config.Holepunch {
|
||||||
|
peerMonitor.AddHolepunchEndpoint(site.SiteId, site.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
logger.Info("Configured peer %s", site.PublicKey)
|
logger.Info("Configured peer %s", site.PublicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
peerMonitor.SetHolepunchStatusCallback(func(siteID int, endpoint string, connected bool, rtt time.Duration) {
|
||||||
|
// This callback is for additional handling if needed
|
||||||
|
// The PeerMonitor already logs status changes
|
||||||
|
logger.Info("+++++++++++++++++++++++++ holepunch monitor callback for site %d, endpoint %s, connected: %v, rtt: %v", siteID, endpoint, connected, rtt)
|
||||||
|
})
|
||||||
|
|
||||||
peerMonitor.Start()
|
peerMonitor.Start()
|
||||||
|
|
||||||
// Set up DNS override to use our DNS proxy
|
// Set up DNS override to use our DNS proxy
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/bind"
|
||||||
|
"github.com/fosrl/newt/holepunch"
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
middleDevice "github.com/fosrl/olm/device"
|
middleDevice "github.com/fosrl/olm/device"
|
||||||
@@ -28,6 +30,9 @@ import (
|
|||||||
// PeerMonitorCallback is the function type for connection status change callbacks
|
// PeerMonitorCallback is the function type for connection status change callbacks
|
||||||
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
type PeerMonitorCallback func(siteID int, connected bool, rtt time.Duration)
|
||||||
|
|
||||||
|
// HolepunchStatusCallback is called when holepunch connection status changes
|
||||||
|
type HolepunchStatusCallback func(siteID int, endpoint string, connected bool, rtt time.Duration)
|
||||||
|
|
||||||
// WireGuardConfig holds the WireGuard configuration for a peer
|
// WireGuardConfig holds the WireGuard configuration for a peer
|
||||||
type WireGuardConfig struct {
|
type WireGuardConfig struct {
|
||||||
SiteID int
|
SiteID int
|
||||||
@@ -62,33 +67,53 @@ type PeerMonitor struct {
|
|||||||
nsCtx context.Context
|
nsCtx context.Context
|
||||||
nsCancel context.CancelFunc
|
nsCancel context.CancelFunc
|
||||||
nsWg sync.WaitGroup
|
nsWg sync.WaitGroup
|
||||||
|
|
||||||
|
// Holepunch testing fields
|
||||||
|
sharedBind *bind.SharedBind
|
||||||
|
holepunchTester *holepunch.HolepunchTester
|
||||||
|
holepunchInterval time.Duration
|
||||||
|
holepunchTimeout time.Duration
|
||||||
|
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||||
|
holepunchStatus map[int]bool // siteID -> connected status
|
||||||
|
holepunchStatusCallback HolepunchStatusCallback
|
||||||
|
holepunchStopChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPeerMonitor creates a new peer monitor with the given callback
|
// NewPeerMonitor creates a new peer monitor with the given callback
|
||||||
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string) *PeerMonitor {
|
func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool, middleDev *middleDevice.MiddleDevice, localIP string, sharedBind *bind.SharedBind) *PeerMonitor {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
pm := &PeerMonitor{
|
pm := &PeerMonitor{
|
||||||
monitors: make(map[int]*Client),
|
monitors: make(map[int]*Client),
|
||||||
configs: make(map[int]*WireGuardConfig),
|
configs: make(map[int]*WireGuardConfig),
|
||||||
callback: callback,
|
callback: callback,
|
||||||
interval: 1 * time.Second, // Default check interval
|
interval: 1 * time.Second, // Default check interval
|
||||||
timeout: 2500 * time.Millisecond,
|
timeout: 2500 * time.Millisecond,
|
||||||
maxAttempts: 15,
|
maxAttempts: 15,
|
||||||
privateKey: privateKey,
|
privateKey: privateKey,
|
||||||
wsClient: wsClient,
|
wsClient: wsClient,
|
||||||
device: device,
|
device: device,
|
||||||
handleRelaySwitch: handleRelaySwitch,
|
handleRelaySwitch: handleRelaySwitch,
|
||||||
middleDev: middleDev,
|
middleDev: middleDev,
|
||||||
localIP: localIP,
|
localIP: localIP,
|
||||||
activePorts: make(map[uint16]bool),
|
activePorts: make(map[uint16]bool),
|
||||||
nsCtx: ctx,
|
nsCtx: ctx,
|
||||||
nsCancel: cancel,
|
nsCancel: cancel,
|
||||||
|
sharedBind: sharedBind,
|
||||||
|
holepunchInterval: 5 * time.Second, // Check holepunch every 5 seconds
|
||||||
|
holepunchTimeout: 3 * time.Second,
|
||||||
|
holepunchEndpoints: make(map[int]string),
|
||||||
|
holepunchStatus: make(map[int]bool),
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.initNetstack(); err != nil {
|
if err := pm.initNetstack(); err != nil {
|
||||||
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
logger.Error("Failed to initialize netstack for peer monitor: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Initialize holepunch tester if sharedBind is available
|
||||||
|
if sharedBind != nil {
|
||||||
|
pm.holepunchTester = holepunch.NewHolepunchTester(sharedBind)
|
||||||
|
}
|
||||||
|
|
||||||
return pm
|
return pm
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -209,6 +234,8 @@ func (pm *PeerMonitor) Start() {
|
|||||||
}
|
}
|
||||||
logger.Info("Started monitoring peer %d\n", siteID)
|
logger.Info("Started monitoring peer %d\n", siteID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pm.startHolepunchMonitor()
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConnectionStatusChange is called when a peer's connection status changes
|
// handleConnectionStatusChange is called when a peer's connection status changes
|
||||||
@@ -282,6 +309,9 @@ func (pm *PeerMonitor) sendRelay(siteID int) error {
|
|||||||
|
|
||||||
// Stop stops monitoring all peers
|
// Stop stops monitoring all peers
|
||||||
func (pm *PeerMonitor) Stop() {
|
func (pm *PeerMonitor) Stop() {
|
||||||
|
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||||
|
pm.stopHolepunchMonitor()
|
||||||
|
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
@@ -297,8 +327,148 @@ func (pm *PeerMonitor) Stop() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetHolepunchStatusCallback sets the callback for holepunch status changes
|
||||||
|
func (pm *PeerMonitor) SetHolepunchStatusCallback(callback HolepunchStatusCallback) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
pm.holepunchStatusCallback = callback
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddHolepunchEndpoint adds an endpoint to monitor via holepunch magic packets
|
||||||
|
func (pm *PeerMonitor) AddHolepunchEndpoint(siteID int, endpoint string) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
pm.holepunchEndpoints[siteID] = endpoint
|
||||||
|
pm.holepunchStatus[siteID] = false // Initially unknown/disconnected
|
||||||
|
logger.Info("Added holepunch monitoring for site %d at %s", siteID, endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveHolepunchEndpoint removes an endpoint from holepunch monitoring
|
||||||
|
func (pm *PeerMonitor) RemoveHolepunchEndpoint(siteID int) {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
delete(pm.holepunchEndpoints, siteID)
|
||||||
|
delete(pm.holepunchStatus, siteID)
|
||||||
|
logger.Info("Removed holepunch monitoring for site %d", siteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// startHolepunchMonitor starts the holepunch connection monitoring
|
||||||
|
// Note: This function assumes the mutex is already held by the caller (called from Start())
|
||||||
|
func (pm *PeerMonitor) startHolepunchMonitor() error {
|
||||||
|
if pm.holepunchTester == nil {
|
||||||
|
return fmt.Errorf("holepunch tester not initialized (sharedBind not provided)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.holepunchStopChan != nil {
|
||||||
|
return fmt.Errorf("holepunch monitor already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := pm.holepunchTester.Start(); err != nil {
|
||||||
|
return fmt.Errorf("failed to start holepunch tester: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pm.holepunchStopChan = make(chan struct{})
|
||||||
|
|
||||||
|
go pm.runHolepunchMonitor()
|
||||||
|
|
||||||
|
logger.Info("Started holepunch connection monitor")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// stopHolepunchMonitor stops the holepunch connection monitoring
|
||||||
|
func (pm *PeerMonitor) stopHolepunchMonitor() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
stopChan := pm.holepunchStopChan
|
||||||
|
pm.holepunchStopChan = nil
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
if stopChan != nil {
|
||||||
|
close(stopChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
if pm.holepunchTester != nil {
|
||||||
|
pm.holepunchTester.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Stopped holepunch connection monitor")
|
||||||
|
}
|
||||||
|
|
||||||
|
// runHolepunchMonitor runs the holepunch monitoring loop
|
||||||
|
func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||||
|
ticker := time.NewTicker(pm.holepunchInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
// Do initial check immediately
|
||||||
|
pm.checkHolepunchEndpoints()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-pm.holepunchStopChan:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
pm.checkHolepunchEndpoints()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkHolepunchEndpoints tests all holepunch endpoints
|
||||||
|
func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
||||||
|
for siteID, endpoint := range pm.holepunchEndpoints {
|
||||||
|
endpoints[siteID] = endpoint
|
||||||
|
}
|
||||||
|
timeout := pm.holepunchTimeout
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
for siteID, endpoint := range endpoints {
|
||||||
|
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||||
|
|
||||||
|
pm.mutex.Lock()
|
||||||
|
previousStatus, exists := pm.holepunchStatus[siteID]
|
||||||
|
pm.holepunchStatus[siteID] = result.Success
|
||||||
|
callback := pm.holepunchStatusCallback
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
// Log status changes
|
||||||
|
if !exists || previousStatus != result.Success {
|
||||||
|
if result.Success {
|
||||||
|
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
||||||
|
} else {
|
||||||
|
if result.Error != nil {
|
||||||
|
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED: %v", siteID, endpoint, result.Error)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Holepunch to site %d (%s) is DISCONNECTED", siteID, endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the callback if set
|
||||||
|
if callback != nil {
|
||||||
|
callback(siteID, endpoint, result.Success, result.RTT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
||||||
|
func (pm *PeerMonitor) GetHolepunchStatus() map[int]bool {
|
||||||
|
pm.mutex.Lock()
|
||||||
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
status := make(map[int]bool, len(pm.holepunchStatus))
|
||||||
|
for siteID, connected := range pm.holepunchStatus {
|
||||||
|
status[siteID] = connected
|
||||||
|
}
|
||||||
|
return status
|
||||||
|
}
|
||||||
|
|
||||||
// Close stops monitoring and cleans up resources
|
// Close stops monitoring and cleans up resources
|
||||||
func (pm *PeerMonitor) Close() {
|
func (pm *PeerMonitor) Close() {
|
||||||
|
// Stop holepunch monitor first (outside of mutex to avoid deadlock)
|
||||||
|
pm.stopHolepunchMonitor()
|
||||||
|
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
defer pm.mutex.Unlock()
|
defer pm.mutex.Unlock()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user