mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
2
go.mod
2
go.mod
@@ -30,3 +30,5 @@ require (
|
||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
|
||||
)
|
||||
|
||||
replace github.com/fosrl/newt => ../newt
|
||||
|
||||
2
go.sum
2
go.sum
@@ -1,7 +1,5 @@
|
||||
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
|
||||
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
|
||||
github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4=
|
||||
github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
|
||||
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
|
||||
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
|
||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
||||
|
||||
@@ -186,7 +186,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
|
||||
56
olm/olm.go
56
olm/olm.go
@@ -46,10 +46,10 @@ type Olm struct {
|
||||
holePunchManager *holepunch.Manager
|
||||
peerManager *peers.PeerManager
|
||||
// Power mode management
|
||||
currentPowerMode string
|
||||
powerModeMu sync.Mutex
|
||||
wakeUpTimer *time.Timer
|
||||
wakeUpDebounce time.Duration
|
||||
currentPowerMode string
|
||||
powerModeMu sync.Mutex
|
||||
wakeUpTimer *time.Timer
|
||||
wakeUpDebounce time.Duration
|
||||
|
||||
olmCtx context.Context
|
||||
tunnelCancel context.CancelFunc
|
||||
@@ -134,7 +134,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
logger.SetOutput(file)
|
||||
logFile = file
|
||||
}
|
||||
|
||||
|
||||
if config.WakeUpDebounce == 0 {
|
||||
config.WakeUpDebounce = 3 * time.Second
|
||||
}
|
||||
@@ -628,23 +628,28 @@ func (o *Olm) SetPowerMode(mode string) error {
|
||||
logger.Info("Switching to low power mode")
|
||||
|
||||
if o.websocket != nil {
|
||||
logger.Info("Closing websocket connection for low power mode")
|
||||
if err := o.websocket.Close(); err != nil {
|
||||
logger.Error("Error closing websocket: %v", err)
|
||||
logger.Info("Disconnecting websocket for low power mode")
|
||||
if err := o.websocket.Disconnect(); err != nil {
|
||||
logger.Error("Error disconnecting websocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
lowPowerInterval := 10 * time.Minute
|
||||
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
lowPowerInterval := 10 * time.Minute
|
||||
peerMonitor.SetInterval(lowPowerInterval)
|
||||
peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
|
||||
peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
logger.Info("Set monitoring intervals to 10 minutes for low power mode")
|
||||
}
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
}
|
||||
|
||||
o.currentPowerMode = "low"
|
||||
logger.Info("Switched to low power mode")
|
||||
|
||||
@@ -673,20 +678,8 @@ func (o *Olm) SetPowerMode(mode string) error {
|
||||
}
|
||||
|
||||
logger.Info("Debounce complete, switching to normal power mode")
|
||||
|
||||
// Restore intervals and reconnect websocket
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.ResetHolepunchInterval()
|
||||
peerMonitor.ResetInterval()
|
||||
}
|
||||
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
|
||||
}
|
||||
|
||||
|
||||
logger.Info("Reconnecting websocket for normal power mode")
|
||||
|
||||
if o.websocket != nil {
|
||||
if err := o.websocket.Connect(); err != nil {
|
||||
logger.Error("Failed to reconnect websocket: %v", err)
|
||||
@@ -694,6 +687,21 @@ func (o *Olm) SetPowerMode(mode string) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Restore intervals and reconnect websocket
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.ResetPeerHolepunchInterval()
|
||||
peerMonitor.ResetPeerInterval()
|
||||
}
|
||||
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
o.currentPowerMode = "normal"
|
||||
logger.Info("Switched to normal power mode")
|
||||
})
|
||||
|
||||
@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
|
||||
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
|
||||
_ = o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetInterval()
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||
|
||||
@@ -28,14 +28,12 @@ import (
|
||||
|
||||
// PeerMonitor handles monitoring the connection status to multiple WireGuard peers
|
||||
type PeerMonitor struct {
|
||||
monitors map[int]*Client
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
defaultInterval time.Duration
|
||||
interval time.Duration
|
||||
monitors map[int]*Client
|
||||
mutex sync.Mutex
|
||||
running bool
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
maxAttempts int
|
||||
wsClient *websocket.Client
|
||||
|
||||
// Netstack fields
|
||||
middleDev *middleDevice.MiddleDevice
|
||||
@@ -54,7 +52,8 @@ type PeerMonitor struct {
|
||||
holepunchTimeout time.Duration
|
||||
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
|
||||
holepunchStatus map[int]bool // siteID -> connected status
|
||||
holepunchStopChan chan struct{}
|
||||
holepunchStopChan chan struct{}
|
||||
holepunchUpdateChan chan struct{}
|
||||
|
||||
// Relay tracking fields
|
||||
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
|
||||
@@ -87,8 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
pm := &PeerMonitor{
|
||||
monitors: make(map[int]*Client),
|
||||
defaultInterval: 2 * time.Second,
|
||||
interval: 2 * time.Second, // Default check interval (faster)
|
||||
timeout: 3 * time.Second,
|
||||
maxAttempts: 3,
|
||||
wsClient: wsClient,
|
||||
@@ -118,6 +115,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
holepunchBackoffMultiplier: 1.5,
|
||||
holepunchStableCount: make(map[int]int),
|
||||
holepunchCurrentInterval: 2 * time.Second,
|
||||
holepunchUpdateChan: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
if err := pm.initNetstack(); err != nil {
|
||||
@@ -133,82 +131,76 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
|
||||
}
|
||||
|
||||
// SetInterval changes how frequently peers are checked
|
||||
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
|
||||
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = interval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(interval)
|
||||
client.SetPacketInterval(minInterval, maxInterval)
|
||||
}
|
||||
|
||||
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) ResetInterval() {
|
||||
func (pm *PeerMonitor) ResetPeerInterval() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.interval = pm.defaultInterval
|
||||
|
||||
// Update interval for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetPacketInterval(pm.defaultInterval)
|
||||
client.ResetPacketInterval()
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
|
||||
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
|
||||
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.timeout = timeout
|
||||
|
||||
// Update timeout for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetTimeout(timeout)
|
||||
}
|
||||
}
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.maxAttempts = attempts
|
||||
|
||||
// Update max attempts for all existing monitors
|
||||
for _, client := range pm.monitors {
|
||||
client.SetMaxAttempts(attempts)
|
||||
}
|
||||
}
|
||||
|
||||
// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
|
||||
func (pm *PeerMonitor) SetHolepunchInterval(minInterval, maxInterval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.holepunchMinInterval = minInterval
|
||||
pm.holepunchMaxInterval = maxInterval
|
||||
// Reset current interval to the new minimum
|
||||
pm.holepunchCurrentInterval = minInterval
|
||||
updateChan := pm.holepunchUpdateChan
|
||||
pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if updateChan != nil {
|
||||
select {
|
||||
case updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
|
||||
func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) {
|
||||
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
|
||||
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
return pm.holepunchMinInterval, pm.holepunchMaxInterval
|
||||
}
|
||||
|
||||
func (pm *PeerMonitor) ResetHolepunchInterval() {
|
||||
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
|
||||
pm.mutex.Lock()
|
||||
defer pm.mutex.Unlock()
|
||||
|
||||
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
|
||||
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
|
||||
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
|
||||
updateChan := pm.holepunchUpdateChan
|
||||
pm.mutex.Unlock()
|
||||
|
||||
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if updateChan != nil {
|
||||
select {
|
||||
case updateChan <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AddPeer adds a new peer to monitor
|
||||
@@ -226,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
|
||||
return err
|
||||
}
|
||||
|
||||
client.SetPacketInterval(pm.interval)
|
||||
client.SetTimeout(pm.timeout)
|
||||
client.SetMaxAttempts(pm.maxAttempts)
|
||||
client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
|
||||
|
||||
pm.monitors[siteID] = client
|
||||
|
||||
pm.holepunchEndpoints[siteID] = holepunchEndpoint
|
||||
@@ -541,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
|
||||
select {
|
||||
case <-pm.holepunchStopChan:
|
||||
return
|
||||
case <-pm.holepunchUpdateChan:
|
||||
// Interval settings changed, reset to minimum
|
||||
pm.mutex.Lock()
|
||||
pm.holepunchCurrentInterval = pm.holepunchMinInterval
|
||||
currentInterval := pm.holepunchCurrentInterval
|
||||
pm.mutex.Unlock()
|
||||
|
||||
timer.Reset(currentInterval)
|
||||
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
|
||||
case <-timer.C:
|
||||
anyStatusChanged := pm.checkHolepunchEndpoints()
|
||||
|
||||
@@ -584,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
|
||||
anyStatusChanged := false
|
||||
|
||||
for siteID, endpoint := range endpoints {
|
||||
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
|
||||
logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
|
||||
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
|
||||
|
||||
pm.mutex.Lock()
|
||||
@@ -733,55 +729,55 @@ func (pm *PeerMonitor) Close() {
|
||||
logger.Debug("PeerMonitor: Cleanup complete")
|
||||
}
|
||||
|
||||
// TestPeer tests connectivity to a specific peer
|
||||
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
pm.mutex.Lock()
|
||||
client, exists := pm.monitors[siteID]
|
||||
pm.mutex.Unlock()
|
||||
// // TestPeer tests connectivity to a specific peer
|
||||
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
|
||||
// pm.mutex.Lock()
|
||||
// client, exists := pm.monitors[siteID]
|
||||
// pm.mutex.Unlock()
|
||||
|
||||
if !exists {
|
||||
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
}
|
||||
// if !exists {
|
||||
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
|
||||
// }
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
defer cancel()
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
// defer cancel()
|
||||
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
return connected, rtt, nil
|
||||
}
|
||||
// connected, rtt := client.TestPeerConnection(ctx)
|
||||
// return connected, rtt, nil
|
||||
// }
|
||||
|
||||
// TestAllPeers tests connectivity to all peers
|
||||
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
} {
|
||||
pm.mutex.Lock()
|
||||
peers := make(map[int]*Client, len(pm.monitors))
|
||||
for siteID, client := range pm.monitors {
|
||||
peers[siteID] = client
|
||||
}
|
||||
pm.mutex.Unlock()
|
||||
// // TestAllPeers tests connectivity to all peers
|
||||
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
|
||||
// Connected bool
|
||||
// RTT time.Duration
|
||||
// } {
|
||||
// pm.mutex.Lock()
|
||||
// peers := make(map[int]*Client, len(pm.monitors))
|
||||
// for siteID, client := range pm.monitors {
|
||||
// peers[siteID] = client
|
||||
// }
|
||||
// pm.mutex.Unlock()
|
||||
|
||||
results := make(map[int]struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
})
|
||||
for siteID, client := range peers {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
connected, rtt := client.TestConnection(ctx)
|
||||
cancel()
|
||||
// results := make(map[int]struct {
|
||||
// Connected bool
|
||||
// RTT time.Duration
|
||||
// })
|
||||
// for siteID, client := range peers {
|
||||
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
|
||||
// connected, rtt := client.TestPeerConnection(ctx)
|
||||
// cancel()
|
||||
|
||||
results[siteID] = struct {
|
||||
Connected bool
|
||||
RTT time.Duration
|
||||
}{
|
||||
Connected: connected,
|
||||
RTT: rtt,
|
||||
}
|
||||
}
|
||||
// results[siteID] = struct {
|
||||
// Connected bool
|
||||
// RTT time.Duration
|
||||
// }{
|
||||
// Connected: connected,
|
||||
// RTT: rtt,
|
||||
// }
|
||||
// }
|
||||
|
||||
return results
|
||||
}
|
||||
// return results
|
||||
// }
|
||||
|
||||
// initNetstack initializes the gvisor netstack
|
||||
func (pm *PeerMonitor) initNetstack() error {
|
||||
|
||||
@@ -32,16 +32,19 @@ type Client struct {
|
||||
monitorLock sync.Mutex
|
||||
connLock sync.Mutex // Protects connection operations
|
||||
shutdownCh chan struct{}
|
||||
updateCh chan struct{}
|
||||
packetInterval time.Duration
|
||||
timeout time.Duration
|
||||
maxAttempts int
|
||||
dialer Dialer
|
||||
|
||||
// Exponential backoff fields
|
||||
minInterval time.Duration // Minimum interval (initial)
|
||||
maxInterval time.Duration // Maximum interval (cap for backoff)
|
||||
backoffMultiplier float64 // Multiplier for each stable check
|
||||
stableCountToBackoff int // Number of stable checks before backing off
|
||||
defaultMinInterval time.Duration // Default minimum interval (initial)
|
||||
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
|
||||
minInterval time.Duration // Minimum interval (initial)
|
||||
maxInterval time.Duration // Maximum interval (cap for backoff)
|
||||
backoffMultiplier float64 // Multiplier for each stable check
|
||||
stableCountToBackoff int // Number of stable checks before backing off
|
||||
}
|
||||
|
||||
// Dialer is a function that creates a connection
|
||||
@@ -56,43 +59,59 @@ type ConnectionStatus struct {
|
||||
// NewClient creates a new connection test client
|
||||
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
|
||||
return &Client{
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
packetInterval: 2 * time.Second,
|
||||
minInterval: 2 * time.Second,
|
||||
maxInterval: 30 * time.Second,
|
||||
backoffMultiplier: 1.5,
|
||||
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
serverAddr: serverAddr,
|
||||
shutdownCh: make(chan struct{}),
|
||||
updateCh: make(chan struct{}, 1),
|
||||
packetInterval: 2 * time.Second,
|
||||
defaultMinInterval: 2 * time.Second,
|
||||
defaultMaxInterval: 30 * time.Second,
|
||||
minInterval: 2 * time.Second,
|
||||
maxInterval: 30 * time.Second,
|
||||
backoffMultiplier: 1.5,
|
||||
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
|
||||
timeout: 500 * time.Millisecond, // Timeout for individual packets
|
||||
maxAttempts: 3, // Default max attempts
|
||||
dialer: dialer,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetPacketInterval changes how frequently packets are sent in monitor mode
|
||||
func (c *Client) SetPacketInterval(interval time.Duration) {
|
||||
c.packetInterval = interval
|
||||
c.minInterval = interval
|
||||
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
|
||||
c.monitorLock.Lock()
|
||||
c.packetInterval = minInterval
|
||||
c.minInterval = minInterval
|
||||
c.maxInterval = maxInterval
|
||||
updateCh := c.updateCh
|
||||
monitorRunning := c.monitorRunning
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if monitorRunning && updateCh != nil {
|
||||
select {
|
||||
case updateCh <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetTimeout changes the timeout for waiting for responses
|
||||
func (c *Client) SetTimeout(timeout time.Duration) {
|
||||
c.timeout = timeout
|
||||
}
|
||||
func (c *Client) ResetPacketInterval() {
|
||||
c.monitorLock.Lock()
|
||||
c.packetInterval = c.defaultMinInterval
|
||||
c.minInterval = c.defaultMinInterval
|
||||
c.maxInterval = c.defaultMaxInterval
|
||||
updateCh := c.updateCh
|
||||
monitorRunning := c.monitorRunning
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// SetMaxAttempts changes the maximum number of attempts for TestConnection
|
||||
func (c *Client) SetMaxAttempts(attempts int) {
|
||||
c.maxAttempts = attempts
|
||||
}
|
||||
|
||||
// SetMaxInterval sets the maximum backoff interval
|
||||
func (c *Client) SetMaxInterval(interval time.Duration) {
|
||||
c.maxInterval = interval
|
||||
}
|
||||
|
||||
// SetBackoffMultiplier sets the multiplier for exponential backoff
|
||||
func (c *Client) SetBackoffMultiplier(multiplier float64) {
|
||||
c.backoffMultiplier = multiplier
|
||||
// Signal the goroutine to apply the new interval if running
|
||||
if monitorRunning && updateCh != nil {
|
||||
select {
|
||||
case updateCh <- struct{}{}:
|
||||
default:
|
||||
// Channel full or closed, skip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateServerAddr updates the server address and resets the connection
|
||||
@@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestConnection checks if the connection to the server is working
|
||||
// TestPeerConnection checks if the connection to the server is working
|
||||
// Returns true if connected, false otherwise
|
||||
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
|
||||
logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
|
||||
if err := c.ensureConnection(); err != nil {
|
||||
logger.Warn("Failed to ensure connection: %v", err)
|
||||
return false, 0
|
||||
@@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
|
||||
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
return c.TestConnection(ctx)
|
||||
return c.TestPeerConnection(ctx)
|
||||
}
|
||||
|
||||
// MonitorCallback is the function type for connection status change callbacks
|
||||
@@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
|
||||
select {
|
||||
case <-c.shutdownCh:
|
||||
return
|
||||
case <-c.updateCh:
|
||||
// Interval settings changed, reset to minimum
|
||||
c.monitorLock.Lock()
|
||||
currentInterval = c.minInterval
|
||||
c.monitorLock.Unlock()
|
||||
|
||||
// Reset backoff state
|
||||
stableCount = 0
|
||||
|
||||
timer.Reset(currentInterval)
|
||||
logger.Debug("Packet interval updated, reset to %v", currentInterval)
|
||||
case <-timer.C:
|
||||
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
|
||||
connected, rtt := c.TestConnection(ctx)
|
||||
connected, rtt := c.TestPeerConnection(ctx)
|
||||
cancel()
|
||||
|
||||
statusChanged := connected != lastConnected
|
||||
@@ -321,4 +352,4 @@ func (c *Client) StopMonitor() {
|
||||
|
||||
close(c.shutdownCh)
|
||||
c.monitorRunning = false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -236,7 +236,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
Data: data,
|
||||
}
|
||||
|
||||
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
@@ -258,7 +258,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
||||
}
|
||||
err := c.SendMessage(messageType, currentData)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send message: %v", err)
|
||||
logger.Error("websocket: Failed to send message: %v", err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
@@ -271,7 +271,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if maxAttempts != -1 && count >= maxAttempts {
|
||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
return
|
||||
}
|
||||
dataMux.Lock()
|
||||
@@ -353,7 +353,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
tokenData := map[string]interface{}{
|
||||
@@ -382,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||
|
||||
// print out the request for debugging
|
||||
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
@@ -399,7 +399,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
|
||||
// Return AuthError for 401/403 status codes
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
@@ -415,7 +415,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
logger.Error("Failed to decode token response.")
|
||||
logger.Error("websocket: Failed to decode token response.")
|
||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
@@ -427,7 +427,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
return "", nil, fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
|
||||
|
||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||
}
|
||||
@@ -442,7 +442,7 @@ func (c *Client) connectWithRetry() {
|
||||
if err != nil {
|
||||
// Check if this is an auth error (401/403)
|
||||
if authErr, ok := err.(*AuthError); ok {
|
||||
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
|
||||
logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr)
|
||||
// Trigger auth error callback if set (this should terminate the tunnel)
|
||||
if c.onAuthError != nil {
|
||||
c.onAuthError(authErr.StatusCode, authErr.Message)
|
||||
@@ -452,7 +452,7 @@ func (c *Client) connectWithRetry() {
|
||||
continue
|
||||
}
|
||||
// For other errors (5xx, network issues), continue retrying
|
||||
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
}
|
||||
@@ -505,7 +505,7 @@ func (c *Client) establishConnection() error {
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
|
||||
tlsConfig, err := c.setupTLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
@@ -519,7 +519,7 @@ func (c *Client) establishConnection() error {
|
||||
dialer.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(u.String(), nil)
|
||||
@@ -537,7 +537,7 @@ func (c *Client) establishConnection() error {
|
||||
|
||||
if c.onConnect != nil {
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("OnConnect callback failed: %v", err)
|
||||
logger.Error("websocket: OnConnect callback failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -550,9 +550,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Handle new separate certificate configuration
|
||||
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||
logger.Info("Loading separate certificate files for mTLS")
|
||||
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||
logger.Info("websocket: Loading separate certificate files for mTLS")
|
||||
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||
|
||||
// Load client certificate and key
|
||||
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||
@@ -563,7 +563,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Load CA certificates for remote validation if specified
|
||||
if len(c.tlsConfig.CAFiles) > 0 {
|
||||
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||
caCertPool := x509.NewCertPool()
|
||||
for _, caFile := range c.tlsConfig.CAFiles {
|
||||
caCert, err := os.ReadFile(caFile)
|
||||
@@ -589,13 +589,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||
if c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
||||
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
|
||||
return c.setupPKCS12TLS()
|
||||
}
|
||||
|
||||
// Legacy fallback using config.TlsClientCert
|
||||
if c.config.TlsClientCert != "" {
|
||||
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||
return loadClientCertificate(c.config.TlsClientCert)
|
||||
}
|
||||
|
||||
@@ -630,7 +630,7 @@ func (c *Client) pingMonitor() {
|
||||
// Expected during shutdown
|
||||
return
|
||||
default:
|
||||
logger.Error("Ping failed: %v", err)
|
||||
logger.Error("websocket: Ping failed: %v", err)
|
||||
c.reconnect()
|
||||
return
|
||||
}
|
||||
@@ -663,18 +663,23 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
||||
var msg WSMessage
|
||||
err := c.conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error
|
||||
// Check if we're shutting down or explicitly disconnected before logging error
|
||||
select {
|
||||
case <-c.done:
|
||||
// Expected during shutdown, don't log as error
|
||||
logger.Debug("WebSocket connection closed during shutdown")
|
||||
logger.Debug("websocket: connection closed during shutdown")
|
||||
return
|
||||
default:
|
||||
// Check if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: connection closed: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
// Unexpected error during normal operation
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||
logger.Error("WebSocket read error: %v", err)
|
||||
logger.Error("websocket: read error: %v", err)
|
||||
} else {
|
||||
logger.Debug("WebSocket connection closed: %v", err)
|
||||
logger.Debug("websocket: connection closed: %v", err)
|
||||
}
|
||||
return // triggers reconnect via defer
|
||||
}
|
||||
@@ -696,6 +701,12 @@ func (c *Client) reconnect() {
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
// Don't reconnect if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
// Only reconnect if we're not shutting down
|
||||
select {
|
||||
case <-c.done:
|
||||
@@ -713,7 +724,7 @@ func (c *Client) setConnected(status bool) {
|
||||
|
||||
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
|
||||
// Read the PKCS12 file
|
||||
p12Data, err := os.ReadFile(p12Path)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user