Power state getting set correctly

Former-commit-id: 0895156efd
This commit is contained in:
Owen
2026-01-14 16:38:40 -08:00
parent 3470da76fc
commit 3ba1714524
8 changed files with 239 additions and 193 deletions

2
go.mod
View File

@@ -30,3 +30,5 @@ require (
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
) )
replace github.com/fosrl/newt => ../newt

2
go.sum
View File

@@ -1,7 +1,5 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4=
github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

View File

@@ -186,7 +186,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
} }
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry // Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{

View File

@@ -628,23 +628,28 @@ func (o *Olm) SetPowerMode(mode string) error {
logger.Info("Switching to low power mode") logger.Info("Switching to low power mode")
if o.websocket != nil { if o.websocket != nil {
logger.Info("Closing websocket connection for low power mode") logger.Info("Disconnecting websocket for low power mode")
if err := o.websocket.Close(); err != nil { if err := o.websocket.Disconnect(); err != nil {
logger.Error("Error closing websocket: %v", err) logger.Error("Error disconnecting websocket: %v", err)
} }
} }
lowPowerInterval := 10 * time.Minute
if o.peerManager != nil { if o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor() peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil { if peerMonitor != nil {
lowPowerInterval := 10 * time.Minute peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
peerMonitor.SetInterval(lowPowerInterval) peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval)
logger.Info("Set monitoring intervals to 10 minutes for low power mode") logger.Info("Set monitoring intervals to 10 minutes for low power mode")
} }
o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
} }
if o.holePunchManager != nil {
o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
}
o.currentPowerMode = "low" o.currentPowerMode = "low"
logger.Info("Switched to low power mode") logger.Info("Switched to low power mode")
@@ -674,19 +679,7 @@ func (o *Olm) SetPowerMode(mode string) error {
logger.Info("Debounce complete, switching to normal power mode") logger.Info("Debounce complete, switching to normal power mode")
// Restore intervals and reconnect websocket
if o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil {
peerMonitor.ResetHolepunchInterval()
peerMonitor.ResetInterval()
}
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
}
logger.Info("Reconnecting websocket for normal power mode") logger.Info("Reconnecting websocket for normal power mode")
if o.websocket != nil { if o.websocket != nil {
if err := o.websocket.Connect(); err != nil { if err := o.websocket.Connect(); err != nil {
logger.Error("Failed to reconnect websocket: %v", err) logger.Error("Failed to reconnect websocket: %v", err)
@@ -694,6 +687,21 @@ func (o *Olm) SetPowerMode(mode string) error {
} }
} }
// Restore intervals and reconnect websocket
if o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil {
peerMonitor.ResetPeerHolepunchInterval()
peerMonitor.ResetPeerInterval()
}
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
}
if o.holePunchManager != nil {
o.holePunchManager.ResetServerHolepunchInterval()
}
o.currentPowerMode = "normal" o.currentPowerMode = "normal"
logger.Info("Switched to normal power mode") logger.Info("Switched to normal power mode")
}) })

View File

@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch() _ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetInterval() o.holePunchManager.ResetServerHolepunchInterval()
} }
logger.Info("Successfully updated peer for site %d", updateData.SiteId) logger.Info("Successfully updated peer for site %d", updateData.SiteId)

View File

@@ -31,8 +31,6 @@ type PeerMonitor struct {
monitors map[int]*Client monitors map[int]*Client
mutex sync.Mutex mutex sync.Mutex
running bool running bool
defaultInterval time.Duration
interval time.Duration
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
wsClient *websocket.Client wsClient *websocket.Client
@@ -55,6 +53,7 @@ type PeerMonitor struct {
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{} holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields // Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -87,8 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{ pm := &PeerMonitor{
monitors: make(map[int]*Client), monitors: make(map[int]*Client),
defaultInterval: 2 * time.Second,
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second, timeout: 3 * time.Second,
maxAttempts: 3, maxAttempts: 3,
wsClient: wsClient, wsClient: wsClient,
@@ -118,6 +115,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
holepunchBackoffMultiplier: 1.5, holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int), holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second, holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
} }
if err := pm.initNetstack(); err != nil { if err := pm.initNetstack(); err != nil {
@@ -133,82 +131,76 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
} }
// SetInterval changes how frequently peers are checked // SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) { func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors // Update interval for all existing monitors
for _, client := range pm.monitors { for _, client := range pm.monitors {
client.SetPacketInterval(interval) client.SetPacketInterval(minInterval, maxInterval)
} }
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
} }
func (pm *PeerMonitor) ResetInterval() { func (pm *PeerMonitor) ResetPeerInterval() {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
pm.interval = pm.defaultInterval
// Update interval for all existing monitors // Update interval for all existing monitors
for _, client := range pm.monitors { for _, client := range pm.monitors {
client.SetPacketInterval(pm.defaultInterval) client.ResetPacketInterval()
} }
} }
// SetTimeout changes the timeout for waiting for responses // SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
}
}
// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.holepunchMinInterval = minInterval pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum // Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
} }
// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring // GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) { func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock() defer pm.mutex.Unlock()
return pm.holepunchMinInterval, pm.holepunchMaxInterval return pm.holepunchMinInterval, pm.holepunchMaxInterval
} }
func (pm *PeerMonitor) ResetHolepunchInterval() { func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
pm.mutex.Lock() pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
} }
// AddPeer adds a new peer to monitor // AddPeer adds a new peer to monitor
@@ -226,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err return err
} }
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
pm.monitors[siteID] = client pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -541,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
select { select {
case <-pm.holepunchStopChan: case <-pm.holepunchStopChan:
return return
case <-pm.holepunchUpdateChan:
// Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C: case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints() anyStatusChanged := pm.checkHolepunchEndpoints()
@@ -584,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
anyStatusChanged := false anyStatusChanged := false
for siteID, endpoint := range endpoints { for siteID, endpoint := range endpoints {
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout) result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock() pm.mutex.Lock()
@@ -733,55 +729,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete") logger.Debug("PeerMonitor: Cleanup complete")
} }
// TestPeer tests connectivity to a specific peer // // TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { // func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock() // pm.mutex.Lock()
client, exists := pm.monitors[siteID] // client, exists := pm.monitors[siteID]
pm.mutex.Unlock() // pm.mutex.Unlock()
if !exists { // if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) // return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
} // }
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel() // defer cancel()
connected, rtt := client.TestConnection(ctx) // connected, rtt := client.TestPeerConnection(ctx)
return connected, rtt, nil // return connected, rtt, nil
} // }
// TestAllPeers tests connectivity to all peers // // TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct { // func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
} { // } {
pm.mutex.Lock() // pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors)) // peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors { // for siteID, client := range pm.monitors {
peers[siteID] = client // peers[siteID] = client
} // }
pm.mutex.Unlock() // pm.mutex.Unlock()
results := make(map[int]struct { // results := make(map[int]struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
}) // })
for siteID, client := range peers { // for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) // ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx) // connected, rtt := client.TestPeerConnection(ctx)
cancel() // cancel()
results[siteID] = struct { // results[siteID] = struct {
Connected bool // Connected bool
RTT time.Duration // RTT time.Duration
}{ // }{
Connected: connected, // Connected: connected,
RTT: rtt, // RTT: rtt,
} // }
} // }
return results // return results
} // }
// initNetstack initializes the gvisor netstack // initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error { func (pm *PeerMonitor) initNetstack() error {

View File

@@ -32,12 +32,15 @@ type Client struct {
monitorLock sync.Mutex monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{} shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration packetInterval time.Duration
timeout time.Duration timeout time.Duration
maxAttempts int maxAttempts int
dialer Dialer dialer Dialer
// Exponential backoff fields // Exponential backoff fields
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial) minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff) maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check backoffMultiplier float64 // Multiplier for each stable check
@@ -58,7 +61,10 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{ return &Client{
serverAddr: serverAddr, serverAddr: serverAddr,
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
updateCh: make(chan struct{}, 1),
packetInterval: 2 * time.Second, packetInterval: 2 * time.Second,
defaultMinInterval: 2 * time.Second,
defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second, minInterval: 2 * time.Second,
maxInterval: 30 * time.Second, maxInterval: 30 * time.Second,
backoffMultiplier: 1.5, backoffMultiplier: 1.5,
@@ -70,29 +76,42 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
} }
// SetPacketInterval changes how frequently packets are sent in monitor mode // SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) { func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.packetInterval = interval c.monitorLock.Lock()
c.minInterval = interval c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
} }
// SetTimeout changes the timeout for waiting for responses func (c *Client) ResetPacketInterval() {
func (c *Client) SetTimeout(timeout time.Duration) { c.monitorLock.Lock()
c.timeout = timeout c.packetInterval = c.defaultMinInterval
} c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetMaxAttempts changes the maximum number of attempts for TestConnection // Signal the goroutine to apply the new interval if running
func (c *Client) SetMaxAttempts(attempts int) { if monitorRunning && updateCh != nil {
c.maxAttempts = attempts select {
} case updateCh <- struct{}{}:
default:
// SetMaxInterval sets the maximum backoff interval // Channel full or closed, skip
func (c *Client) SetMaxInterval(interval time.Duration) { }
c.maxInterval = interval }
}
// SetBackoffMultiplier sets the multiplier for exponential backoff
func (c *Client) SetBackoffMultiplier(multiplier float64) {
c.backoffMultiplier = multiplier
} }
// UpdateServerAddr updates the server address and resets the connection // UpdateServerAddr updates the server address and resets the connection
@@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error {
return nil return nil
} }
// TestConnection checks if the connection to the server is working // TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise // Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
if err := c.ensureConnection(); err != nil { if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err) logger.Warn("Failed to ensure connection: %v", err)
return false, 0 return false, 0
@@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel() defer cancel()
return c.TestConnection(ctx) return c.TestPeerConnection(ctx)
} }
// MonitorCallback is the function type for connection status change callbacks // MonitorCallback is the function type for connection status change callbacks
@@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
select { select {
case <-c.shutdownCh: case <-c.shutdownCh:
return return
case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C: case <-timer.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout) ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx) connected, rtt := c.TestPeerConnection(ctx)
cancel() cancel()
statusChanged := connected != lastConnected statusChanged := connected != lastConnected

View File

@@ -236,7 +236,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data, Data: data,
} }
logger.Debug("Sending message: %s, data: %+v", messageType, data) logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock() c.writeMux.Lock()
defer c.writeMux.Unlock() defer c.writeMux.Unlock()
@@ -258,7 +258,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
} }
err := c.SendMessage(messageType, currentData) err := c.SendMessage(messageType, currentData)
if err != nil { if err != nil {
logger.Error("Failed to send message: %v", err) logger.Error("websocket: Failed to send message: %v", err)
} }
count++ count++
} }
@@ -271,7 +271,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
select { select {
case <-ticker.C: case <-ticker.C:
if maxAttempts != -1 && count >= maxAttempts { if maxAttempts != -1 && count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return return
} }
dataMux.Lock() dataMux.Lock()
@@ -353,7 +353,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{} tlsConfig = &tls.Config{}
} }
tlsConfig.InsecureSkipVerify = true tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
tokenData := map[string]interface{}{ tokenData := map[string]interface{}{
@@ -382,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection") req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging // print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request // Make the request
client := &http.Client{} client := &http.Client{}
@@ -399,7 +399,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes // Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -415,7 +415,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.") logger.Error("websocket: Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err) return "", nil, fmt.Errorf("failed to decode token response: %w", err)
} }
@@ -427,7 +427,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server") return "", nil, fmt.Errorf("received empty token from server")
} }
logger.Debug("Received token: %s", tokenResp.Data.Token) logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
} }
@@ -442,7 +442,7 @@ func (c *Client) connectWithRetry() {
if err != nil { if err != nil {
// Check if this is an auth error (401/403) // Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok { if authErr, ok := err.(*AuthError); ok {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel) // Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil { if c.onAuthError != nil {
c.onAuthError(authErr.StatusCode, authErr.Message) c.onAuthError(authErr.StatusCode, authErr.Message)
@@ -452,7 +452,7 @@ func (c *Client) connectWithRetry() {
continue continue
} }
// For other errors (5xx, network issues), continue retrying // For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval) time.Sleep(c.reconnectInterval)
continue continue
} }
@@ -505,7 +505,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method // Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection") logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS() tlsConfig, err := c.setupTLS()
if err != nil { if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err) return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -519,7 +519,7 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{} dialer.TLSClientConfig = &tls.Config{}
} }
dialer.TLSClientConfig.InsecureSkipVerify = true dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
conn, _, err := dialer.Dial(u.String(), nil) conn, _, err := dialer.Dial(u.String(), nil)
@@ -537,7 +537,7 @@ func (c *Client) establishConnection() error {
if c.onConnect != nil { if c.onConnect != nil {
if err := c.onConnect(); err != nil { if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err) logger.Error("websocket: OnConnect callback failed: %v", err)
} }
} }
@@ -550,9 +550,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration // Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS") logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key // Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -563,7 +563,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified // Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 { if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles { for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile) caCert, err := os.ReadFile(caFile)
@@ -589,13 +589,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility // Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" { if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS() return c.setupPKCS12TLS()
} }
// Legacy fallback using config.TlsClientCert // Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" { if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert) return loadClientCertificate(c.config.TlsClientCert)
} }
@@ -630,7 +630,7 @@ func (c *Client) pingMonitor() {
// Expected during shutdown // Expected during shutdown
return return
default: default:
logger.Error("Ping failed: %v", err) logger.Error("websocket: Ping failed: %v", err)
c.reconnect() c.reconnect()
return return
} }
@@ -663,18 +663,23 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage var msg WSMessage
err := c.conn.ReadJSON(&msg) err := c.conn.ReadJSON(&msg)
if err != nil { if err != nil {
// Check if we're shutting down before logging error // Check if we're shutting down or explicitly disconnected before logging error
select { select {
case <-c.done: case <-c.done:
// Expected during shutdown, don't log as error // Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown") logger.Debug("websocket: connection closed during shutdown")
return return
default: default:
// Check if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: connection closed: client was explicitly disconnected")
return
}
// Unexpected error during normal operation // Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err) logger.Error("websocket: read error: %v", err)
} else { } else {
logger.Debug("WebSocket connection closed: %v", err) logger.Debug("websocket: connection closed: %v", err)
} }
return // triggers reconnect via defer return // triggers reconnect via defer
} }
@@ -696,6 +701,12 @@ func (c *Client) reconnect() {
c.conn = nil c.conn = nil
} }
// Don't reconnect if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
return
}
// Only reconnect if we're not shutting down // Only reconnect if we're not shutting down
select { select {
case <-c.done: case <-c.done:
@@ -713,7 +724,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format) // LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) { func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path) logger.Info("websocket: Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file // Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path) p12Data, err := os.ReadFile(p12Path)
if err != nil { if err != nil {