mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
1
main.go
1
main.go
@@ -219,6 +219,7 @@ func runOlmMainWithArgs(ctx context.Context, cancel context.CancelFunc, signalCt
|
|||||||
Agent: "Olm CLI",
|
Agent: "Olm CLI",
|
||||||
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
OnExit: cancel, // Pass cancel function directly to trigger shutdown
|
||||||
OnTerminated: cancel,
|
OnTerminated: cancel,
|
||||||
|
PprofAddr: ":4444", // TODO: REMOVE OR MAKE CONFIGURABLE
|
||||||
}
|
}
|
||||||
|
|
||||||
olm.Init(ctx, olmConfig)
|
olm.Init(ctx, olmConfig)
|
||||||
|
|||||||
12
olm/olm.go
12
olm/olm.go
@@ -5,6 +5,8 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -101,6 +103,16 @@ func Init(ctx context.Context, config GlobalConfig) {
|
|||||||
|
|
||||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||||
|
|
||||||
|
// Start pprof server if enabled
|
||||||
|
if config.PprofAddr != "" {
|
||||||
|
go func() {
|
||||||
|
logger.Info("Starting pprof server on %s", config.PprofAddr)
|
||||||
|
if err := http.ListenAndServe(config.PprofAddr, nil); err != nil {
|
||||||
|
logger.Error("Failed to start pprof server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
logger.Debug("Checking permissions for native interface")
|
logger.Debug("Checking permissions for native interface")
|
||||||
err := permissions.CheckNativeInterfacePermissions()
|
err := permissions.CheckNativeInterfacePermissions()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ type GlobalConfig struct {
|
|||||||
Version string
|
Version string
|
||||||
Agent string
|
Agent string
|
||||||
|
|
||||||
|
// Debugging
|
||||||
|
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
|
||||||
|
|
||||||
// Callbacks
|
// Callbacks
|
||||||
OnRegistered func()
|
OnRegistered func()
|
||||||
OnConnected func()
|
OnConnected func()
|
||||||
|
|||||||
@@ -61,6 +61,13 @@ type PeerMonitor struct {
|
|||||||
holepunchMaxAttempts int // max consecutive failures before triggering relay
|
holepunchMaxAttempts int // max consecutive failures before triggering relay
|
||||||
holepunchFailures map[int]int // siteID -> consecutive failure count
|
holepunchFailures map[int]int // siteID -> consecutive failure count
|
||||||
|
|
||||||
|
// Exponential backoff fields for holepunch monitor
|
||||||
|
holepunchMinInterval time.Duration // Minimum interval (initial)
|
||||||
|
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
|
||||||
|
holepunchBackoffMultiplier float64 // Multiplier for each stable check
|
||||||
|
holepunchStableCount map[int]int // siteID -> consecutive stable status count
|
||||||
|
holepunchCurrentInterval time.Duration // Current interval with backoff applied
|
||||||
|
|
||||||
// Rapid initial test fields
|
// Rapid initial test fields
|
||||||
rapidTestInterval time.Duration // interval between rapid test attempts
|
rapidTestInterval time.Duration // interval between rapid test attempts
|
||||||
rapidTestTimeout time.Duration // timeout for each rapid test attempt
|
rapidTestTimeout time.Duration // timeout for each rapid test attempt
|
||||||
@@ -101,6 +108,12 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
|||||||
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
rapidTestMaxAttempts: 5, // 5 attempts = ~1-1.5 seconds total
|
||||||
apiServer: apiServer,
|
apiServer: apiServer,
|
||||||
wgConnectionStatus: make(map[int]bool),
|
wgConnectionStatus: make(map[int]bool),
|
||||||
|
// Exponential backoff settings for holepunch monitor
|
||||||
|
holepunchMinInterval: 2 * time.Second,
|
||||||
|
holepunchMaxInterval: 30 * time.Second,
|
||||||
|
holepunchBackoffMultiplier: 1.5,
|
||||||
|
holepunchStableCount: make(map[int]int),
|
||||||
|
holepunchCurrentInterval: 2 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pm.initNetstack(); err != nil {
|
if err := pm.initNetstack(); err != nil {
|
||||||
@@ -172,6 +185,7 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
|
|||||||
client.SetPacketInterval(pm.interval)
|
client.SetPacketInterval(pm.interval)
|
||||||
client.SetTimeout(pm.timeout)
|
client.SetTimeout(pm.timeout)
|
||||||
client.SetMaxAttempts(pm.maxAttempts)
|
client.SetMaxAttempts(pm.maxAttempts)
|
||||||
|
client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
|
||||||
|
|
||||||
pm.monitors[siteID] = client
|
pm.monitors[siteID] = client
|
||||||
|
|
||||||
@@ -470,31 +484,50 @@ func (pm *PeerMonitor) stopHolepunchMonitor() {
|
|||||||
logger.Info("Stopped holepunch connection monitor")
|
logger.Info("Stopped holepunch connection monitor")
|
||||||
}
|
}
|
||||||
|
|
||||||
// runHolepunchMonitor runs the holepunch monitoring loop
|
// runHolepunchMonitor runs the holepunch monitoring loop with exponential backoff
|
||||||
func (pm *PeerMonitor) runHolepunchMonitor() {
|
func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||||
ticker := time.NewTicker(pm.holepunchInterval)
|
pm.mutex.Lock()
|
||||||
defer ticker.Stop()
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
// Do initial check immediately
|
timer := time.NewTimer(0) // Fire immediately for initial check
|
||||||
pm.checkHolepunchEndpoints()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-pm.holepunchStopChan:
|
case <-pm.holepunchStopChan:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-timer.C:
|
||||||
pm.checkHolepunchEndpoints()
|
anyStatusChanged := pm.checkHolepunchEndpoints()
|
||||||
|
|
||||||
|
pm.mutex.Lock()
|
||||||
|
if anyStatusChanged {
|
||||||
|
// Reset to minimum interval on any status change
|
||||||
|
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||||
|
} else {
|
||||||
|
// Apply exponential backoff when stable
|
||||||
|
newInterval := time.Duration(float64(pm.holepunchCurrentInterval) * pm.holepunchBackoffMultiplier)
|
||||||
|
if newInterval > pm.holepunchMaxInterval {
|
||||||
|
newInterval = pm.holepunchMaxInterval
|
||||||
|
}
|
||||||
|
pm.holepunchCurrentInterval = newInterval
|
||||||
|
}
|
||||||
|
currentInterval := pm.holepunchCurrentInterval
|
||||||
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
timer.Reset(currentInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// checkHolepunchEndpoints tests all holepunch endpoints
|
// checkHolepunchEndpoints tests all holepunch endpoints
|
||||||
func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
// Returns true if any endpoint's status changed
|
||||||
|
func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
|
||||||
pm.mutex.Lock()
|
pm.mutex.Lock()
|
||||||
// Check if we're still running before doing any work
|
// Check if we're still running before doing any work
|
||||||
if !pm.running {
|
if !pm.running {
|
||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
return
|
return false
|
||||||
}
|
}
|
||||||
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
endpoints := make(map[int]string, len(pm.holepunchEndpoints))
|
||||||
for siteID, endpoint := range pm.holepunchEndpoints {
|
for siteID, endpoint := range pm.holepunchEndpoints {
|
||||||
@@ -504,6 +537,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
maxAttempts := pm.holepunchMaxAttempts
|
maxAttempts := pm.holepunchMaxAttempts
|
||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
|
anyStatusChanged := false
|
||||||
|
|
||||||
for siteID, endpoint := range endpoints {
|
for siteID, endpoint := range endpoints {
|
||||||
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
|
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
|
||||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||||
@@ -529,7 +564,9 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
// Log status changes
|
// Log status changes
|
||||||
if !exists || previousStatus != result.Success {
|
statusChanged := !exists || previousStatus != result.Success
|
||||||
|
if statusChanged {
|
||||||
|
anyStatusChanged = true
|
||||||
if result.Success {
|
if result.Success {
|
||||||
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
logger.Info("Holepunch to site %d (%s) is CONNECTED (RTT: %v)", siteID, endpoint, result.RTT)
|
||||||
} else {
|
} else {
|
||||||
@@ -562,7 +599,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
pm.mutex.Unlock()
|
pm.mutex.Unlock()
|
||||||
|
|
||||||
if !stillRunning {
|
if !stillRunning {
|
||||||
return // Stop processing if shutdown is in progress
|
return anyStatusChanged // Stop processing if shutdown is in progress
|
||||||
}
|
}
|
||||||
|
|
||||||
if !result.Success && !isRelayed && failureCount >= maxAttempts {
|
if !result.Success && !isRelayed && failureCount >= maxAttempts {
|
||||||
@@ -579,6 +616,8 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return anyStatusChanged
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
// GetHolepunchStatus returns the current holepunch status for all endpoints
|
||||||
|
|||||||
@@ -36,6 +36,12 @@ type Client struct {
|
|||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
maxAttempts int
|
maxAttempts int
|
||||||
dialer Dialer
|
dialer Dialer
|
||||||
|
|
||||||
|
// Exponential backoff fields
|
||||||
|
minInterval time.Duration // Minimum interval (initial)
|
||||||
|
maxInterval time.Duration // Maximum interval (cap for backoff)
|
||||||
|
backoffMultiplier float64 // Multiplier for each stable check
|
||||||
|
stableCountToBackoff int // Number of stable checks before backing off
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dialer is a function that creates a connection
|
// Dialer is a function that creates a connection
|
||||||
@@ -53,6 +59,10 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
|||||||
serverAddr: serverAddr,
|
serverAddr: serverAddr,
|
||||||
shutdownCh: make(chan struct{}),
|
shutdownCh: make(chan struct{}),
|
||||||
packetInterval: 2 * time.Second,
|
packetInterval: 2 * time.Second,
|
||||||
|
minInterval: 2 * time.Second,
|
||||||
|
maxInterval: 30 * time.Second,
|
||||||
|
backoffMultiplier: 1.5,
|
||||||
|
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
|
||||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||||
maxAttempts: 3, // Default max attempts
|
maxAttempts: 3, // Default max attempts
|
||||||
dialer: dialer,
|
dialer: dialer,
|
||||||
@@ -62,6 +72,7 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
|||||||
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||||
func (c *Client) SetPacketInterval(interval time.Duration) {
|
func (c *Client) SetPacketInterval(interval time.Duration) {
|
||||||
c.packetInterval = interval
|
c.packetInterval = interval
|
||||||
|
c.minInterval = interval
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetTimeout changes the timeout for waiting for responses
|
// SetTimeout changes the timeout for waiting for responses
|
||||||
@@ -74,6 +85,16 @@ func (c *Client) SetMaxAttempts(attempts int) {
|
|||||||
c.maxAttempts = attempts
|
c.maxAttempts = attempts
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetMaxInterval sets the maximum backoff interval
|
||||||
|
func (c *Client) SetMaxInterval(interval time.Duration) {
|
||||||
|
c.maxInterval = interval
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBackoffMultiplier sets the multiplier for exponential backoff
|
||||||
|
func (c *Client) SetBackoffMultiplier(multiplier float64) {
|
||||||
|
c.backoffMultiplier = multiplier
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateServerAddr updates the server address and resets the connection
|
// UpdateServerAddr updates the server address and resets the connection
|
||||||
func (c *Client) UpdateServerAddr(serverAddr string) {
|
func (c *Client) UpdateServerAddr(serverAddr string) {
|
||||||
c.connLock.Lock()
|
c.connLock.Lock()
|
||||||
@@ -138,6 +159,9 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|||||||
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
binary.BigEndian.PutUint32(packet[0:4], magicHeader)
|
||||||
packet[4] = packetTypeRequest
|
packet[4] = packetTypeRequest
|
||||||
|
|
||||||
|
// Reusable response buffer
|
||||||
|
responseBuffer := make([]byte, packetSize)
|
||||||
|
|
||||||
// Send multiple attempts as specified
|
// Send multiple attempts as specified
|
||||||
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
for attempt := 0; attempt < c.maxAttempts; attempt++ {
|
||||||
select {
|
select {
|
||||||
@@ -157,20 +181,17 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
|||||||
return false, 0
|
return false, 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// logger.Debug("Attempting to send monitor packet to %s", c.serverAddr)
|
|
||||||
_, err := c.conn.Write(packet)
|
_, err := c.conn.Write(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.connLock.Unlock()
|
c.connLock.Unlock()
|
||||||
logger.Info("Error sending packet: %v", err)
|
logger.Info("Error sending packet: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
// logger.Debug("Successfully sent monitor packet")
|
|
||||||
|
|
||||||
// Set read deadline
|
// Set read deadline
|
||||||
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
c.conn.SetReadDeadline(time.Now().Add(c.timeout))
|
||||||
|
|
||||||
// Wait for response
|
// Wait for response
|
||||||
responseBuffer := make([]byte, packetSize)
|
|
||||||
n, err := c.conn.Read(responseBuffer)
|
n, err := c.conn.Read(responseBuffer)
|
||||||
c.connLock.Unlock()
|
c.connLock.Unlock()
|
||||||
|
|
||||||
@@ -238,28 +259,50 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
|
|||||||
go func() {
|
go func() {
|
||||||
var lastConnected bool
|
var lastConnected bool
|
||||||
firstRun := true
|
firstRun := true
|
||||||
|
stableCount := 0
|
||||||
|
currentInterval := c.minInterval
|
||||||
|
|
||||||
ticker := time.NewTicker(c.packetInterval)
|
timer := time.NewTimer(currentInterval)
|
||||||
defer ticker.Stop()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-c.shutdownCh:
|
case <-c.shutdownCh:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-timer.C:
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||||
connected, rtt := c.TestConnection(ctx)
|
connected, rtt := c.TestConnection(ctx)
|
||||||
cancel()
|
cancel()
|
||||||
|
|
||||||
|
statusChanged := connected != lastConnected
|
||||||
|
|
||||||
// Callback if status changed or it's the first check
|
// Callback if status changed or it's the first check
|
||||||
if connected != lastConnected || firstRun {
|
if statusChanged || firstRun {
|
||||||
callback(ConnectionStatus{
|
callback(ConnectionStatus{
|
||||||
Connected: connected,
|
Connected: connected,
|
||||||
RTT: rtt,
|
RTT: rtt,
|
||||||
})
|
})
|
||||||
lastConnected = connected
|
lastConnected = connected
|
||||||
firstRun = false
|
firstRun = false
|
||||||
|
// Reset backoff on status change
|
||||||
|
stableCount = 0
|
||||||
|
currentInterval = c.minInterval
|
||||||
|
} else {
|
||||||
|
// Status is stable, increment counter
|
||||||
|
stableCount++
|
||||||
|
|
||||||
|
// Apply exponential backoff after stable threshold
|
||||||
|
if stableCount >= c.stableCountToBackoff {
|
||||||
|
newInterval := time.Duration(float64(currentInterval) * c.backoffMultiplier)
|
||||||
|
if newInterval > c.maxInterval {
|
||||||
|
newInterval = c.maxInterval
|
||||||
}
|
}
|
||||||
|
currentInterval = newInterval
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset timer with current interval
|
||||||
|
timer.Reset(currentInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
Reference in New Issue
Block a user