From 3ba171452488a9a944687617a509009edce84a83 Mon Sep 17 00:00:00 2001
From: Owen
Date: Wed, 14 Jan 2026 16:38:40 -0800
Subject: [PATCH] Set power state correctly

Rename the peer monitor interval setters to SetPeerInterval/ResetPeerInterval
and SetPeerHolepunchInterval/ResetPeerHolepunchInterval, pass explicit min/max
intervals through to each client, and signal running monitor goroutines over
update channels so interval changes take effect immediately. Low power mode now
disconnects the websocket instead of closing it, stretches monitoring and
holepunch intervals to 10 minutes, and restores the defaults (and the websocket
connection) when returning to normal power mode.

Former-commit-id: 0895156efd764c365b9196e55dcb1199b3ec9b1c
---
 go.mod                    |   2 +
 go.sum                    |   2 -
 olm/data.go               |   2 +-
 olm/olm.go                |  56 ++++++-----
 olm/peer.go               |   2 +-
 peers/monitor/monitor.go  | 198 +++++++++++++++++++-------------------
 peers/monitor/wgtester.go | 109 +++++++++++++--------
 websocket/client.go       |  61 +++++++-----
 8 files changed, 239 insertions(+), 193 deletions(-)

diff --git a/go.mod b/go.mod
index 4f42df6..0d6bbcb 100644
--- a/go.mod
+++ b/go.mod
@@ -30,3 +30,5 @@ require (
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
 	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 )
+
+replace github.com/fosrl/newt => ../newt
diff --git a/go.sum b/go.sum
index a543b5a..f6ca61a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,7 +1,5 @@
 github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
 github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4=
-github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
 github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
 github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
 github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
diff --git a/olm/data.go b/olm/data.go
index 93e64d0..fe0b36a 100644
--- a/olm/data.go
+++ b/olm/data.go
@@ -186,7 +186,7 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
 	}
 
 	o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
-	o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
+	o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
 
 	// Send handshake acknowledgment back to server with retry
 	o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
diff --git a/olm/olm.go b/olm/olm.go
index 6a0a26f..3f197ae 100644
--- a/olm/olm.go
+++ b/olm/olm.go
@@ -46,10 +46,10 @@ type Olm struct {
 	holePunchManager *holepunch.Manager
 	peerManager      *peers.PeerManager
 	// Power mode management
-	currentPowerMode  string
-	powerModeMu       sync.Mutex
-	wakeUpTimer       *time.Timer
-	wakeUpDebounce    time.Duration
+	currentPowerMode string
+	powerModeMu      sync.Mutex
+	wakeUpTimer      *time.Timer
+	wakeUpDebounce   time.Duration
 
 	olmCtx       context.Context
 	tunnelCancel context.CancelFunc
@@ -134,7 +134,7 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
 		logger.SetOutput(file)
 		logFile = file
 	}
-	
+
 	if config.WakeUpDebounce == 0 {
 		config.WakeUpDebounce = 3 * time.Second
 	}
@@ -628,23 +628,28 @@ func (o *Olm) SetPowerMode(mode string) error {
 		logger.Info("Switching to low power mode")
 
 		if o.websocket != nil {
-			logger.Info("Closing websocket connection for low power mode")
-			if err := o.websocket.Close(); err != nil {
-				logger.Error("Error closing websocket: %v", err)
+			logger.Info("Disconnecting websocket for low power mode")
+			if err := o.websocket.Disconnect(); err != nil {
+				logger.Error("Error disconnecting websocket: %v", err)
 			}
 		}
 
+		lowPowerInterval := 10 * time.Minute
+
 		if o.peerManager != nil {
 			peerMonitor := o.peerManager.GetPeerMonitor()
 			if peerMonitor != nil {
-				lowPowerInterval := 10 * time.Minute
-				peerMonitor.SetInterval(lowPowerInterval)
-				peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval)
+				peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
+				peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
 				logger.Info("Set monitoring intervals to 10 minutes for low power mode")
 			}
 
 			o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
 		}
 
+		if o.holePunchManager != nil {
+			o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
+		}
+
 		o.currentPowerMode = "low"
 		logger.Info("Switched to low power mode")
 
@@ -673,20 +678,8 @@ func (o *Olm) SetPowerMode(mode string) error {
 			}
 
 			logger.Info("Debounce complete, switching to normal power mode")
-
-			// Restore intervals and reconnect websocket
-			if o.peerManager != nil {
-				peerMonitor := o.peerManager.GetPeerMonitor()
-				if peerMonitor != nil {
-					peerMonitor.ResetHolepunchInterval()
-					peerMonitor.ResetInterval()
-				}
-
-				o.peerManager.UpdateAllPeersPersistentKeepalive(5)
-			}
-
+			logger.Info("Reconnecting websocket for normal power mode")
-
 			if o.websocket != nil {
 				if err := o.websocket.Connect(); err != nil {
 					logger.Error("Failed to reconnect websocket: %v", err)
@@ -694,6 +687,21 @@ func (o *Olm) SetPowerMode(mode string) error {
 				}
 			}
 
+			// Restore monitoring intervals now that the websocket is reconnected
+			if o.peerManager != nil {
+				peerMonitor := o.peerManager.GetPeerMonitor()
+				if peerMonitor != nil {
+					peerMonitor.ResetPeerHolepunchInterval()
+					peerMonitor.ResetPeerInterval()
+				}
+
+				o.peerManager.UpdateAllPeersPersistentKeepalive(5)
+			}
+
+			if o.holePunchManager != nil {
+				o.holePunchManager.ResetServerHolepunchInterval()
+			}
+
 			o.currentPowerMode = "normal"
 			logger.Info("Switched to normal power mode")
 		})
diff --git a/olm/peer.go b/olm/peer.go
index 8acec42..9bc842e 100644
--- a/olm/peer.go
+++ b/olm/peer.go
@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
 	if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
 		logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
 		_ = o.holePunchManager.TriggerHolePunch()
-		o.holePunchManager.ResetInterval()
+		o.holePunchManager.ResetServerHolepunchInterval()
 	}
 
 	logger.Info("Successfully updated peer for site %d", updateData.SiteId)
diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go
index 3ac4b54..387b82f 100644
--- a/peers/monitor/monitor.go
+++ b/peers/monitor/monitor.go
@@ -28,14 +28,12 @@ import (
 
 // PeerMonitor handles monitoring the connection status to multiple WireGuard peers
 type PeerMonitor struct {
-	monitors        map[int]*Client
-	mutex           sync.Mutex
-	running         bool
-	defaultInterval time.Duration
-	interval        time.Duration
+	monitors    map[int]*Client
+	mutex       sync.Mutex
+	running     bool
 	timeout     time.Duration
-	maxAttempts     int
-	wsClient        *websocket.Client
+	maxAttempts int
+	wsClient    *websocket.Client
 
 	// Netstack fields
 	middleDev *middleDevice.MiddleDevice
@@ -54,7 +52,8 @@ type PeerMonitor struct {
 	holepunchTimeout   time.Duration
 	holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
 	holepunchStatus    map[int]bool   // siteID -> connected status
-	holepunchStopChan  chan struct{}
+	holepunchStopChan   chan struct{}
+	holepunchUpdateChan chan struct{}
 
 	// Relay tracking fields
 	relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -87,8 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 	ctx, cancel := context.WithCancel(context.Background())
 	pm := &PeerMonitor{
 		monitors:        make(map[int]*Client),
-		defaultInterval: 2 * time.Second,
-		interval:        2 * time.Second, // Default check interval (faster)
 		timeout:         3 * time.Second,
 		maxAttempts:     3,
 		wsClient:        wsClient,
@@ -118,6 +115,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		holepunchBackoffMultiplier: 1.5,
 		holepunchStableCount:       make(map[int]int),
 		holepunchCurrentInterval:   2 * time.Second,
+		holepunchUpdateChan:        make(chan struct{}, 1),
 	}
 
 	if err := pm.initNetstack(); err != nil {
@@ -133,82 +131,76 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 }
 
 // SetInterval changes how frequently peers are checked
-func (pm *PeerMonitor) SetInterval(interval time.Duration) {
+func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 
-	pm.interval = interval
-
 	// Update interval for all existing monitors
 	for _, client := range pm.monitors {
-		client.SetPacketInterval(interval)
+		client.SetPacketInterval(minInterval, maxInterval)
 	}
+
+	logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
 }
 
-func (pm *PeerMonitor) ResetInterval() {
+func (pm *PeerMonitor) ResetPeerInterval() {
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 
-	pm.interval = pm.defaultInterval
-
 	// Update interval for all existing monitors
 	for _, client := range pm.monitors {
-		client.SetPacketInterval(pm.defaultInterval)
+		client.ResetPacketInterval()
 	}
 }
 
-// SetTimeout changes the timeout for waiting for responses
-func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
+// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
+func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
 	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-
-	pm.timeout = timeout
-
-	// Update timeout for all existing monitors
-	for _, client := range pm.monitors {
-		client.SetTimeout(timeout)
-	}
-}
-
-// SetMaxAttempts changes the maximum number of attempts for TestConnection
-func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-
-	pm.maxAttempts = attempts
-
-	// Update max attempts for all existing monitors
-	for _, client := range pm.monitors {
-		client.SetMaxAttempts(attempts)
-	}
-}
-
-// SetHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
-func (pm *PeerMonitor) SetHolepunchInterval(minInterval, maxInterval time.Duration) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-
 	pm.holepunchMinInterval = minInterval
 	pm.holepunchMaxInterval = maxInterval
 	// Reset current interval to the new minimum
 	pm.holepunchCurrentInterval = minInterval
+	updateChan := pm.holepunchUpdateChan
+	pm.mutex.Unlock()
+
+	logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
+
+	// Signal the goroutine to apply the new interval if running
+	if updateChan != nil {
+		select {
+		case updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
 }
 
-// GetHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
-func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Duration) {
+// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
+func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
 	pm.mutex.Lock()
 	defer pm.mutex.Unlock()
 	return pm.holepunchMinInterval, pm.holepunchMaxInterval
 }
 
-func (pm *PeerMonitor) ResetHolepunchInterval() {
+func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
 	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-
 	pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
 	pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
 	pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
+	updateChan := pm.holepunchUpdateChan
+	pm.mutex.Unlock()
+
+	logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
+
+	// Signal the goroutine to apply the new interval if running
+	if updateChan != nil {
+		select {
+		case updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
 }
 
 // AddPeer adds a new peer to monitor
@@ -226,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
 		return err
 	}
 
-	client.SetPacketInterval(pm.interval)
-	client.SetTimeout(pm.timeout)
-	client.SetMaxAttempts(pm.maxAttempts)
-	client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
-
 	pm.monitors[siteID] = client
 	pm.holepunchEndpoints[siteID] = holepunchEndpoint
 
@@ -541,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
 		select {
 		case <-pm.holepunchStopChan:
 			return
+		case <-pm.holepunchUpdateChan:
+			// Interval settings changed, reset to minimum
+			pm.mutex.Lock()
+			pm.holepunchCurrentInterval = pm.holepunchMinInterval
+			currentInterval := pm.holepunchCurrentInterval
+			pm.mutex.Unlock()
+
+			timer.Reset(currentInterval)
+			logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
 		case <-timer.C:
 			anyStatusChanged := pm.checkHolepunchEndpoints()
 
@@ -584,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
 	anyStatusChanged := false
 
 	for siteID, endpoint := range endpoints {
-		// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
+		logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
 		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
 
 		pm.mutex.Lock()
@@ -733,55 +729,55 @@ func (pm *PeerMonitor) Close() {
 	logger.Debug("PeerMonitor: Cleanup complete")
 }
 
-// TestPeer tests connectivity to a specific peer
-func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
-	pm.mutex.Lock()
-	client, exists := pm.monitors[siteID]
-	pm.mutex.Unlock()
+// // TestPeer tests connectivity to a specific peer
+// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
+// 	pm.mutex.Lock()
+// 	client, exists := pm.monitors[siteID]
+// 	pm.mutex.Unlock()
 
-	if !exists {
-		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
-	}
+// 	if !exists {
+// 		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
+// 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
-	defer cancel()
+// 	ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
+// 	defer cancel()
 
-	connected, rtt := client.TestConnection(ctx)
-	return connected, rtt, nil
-}
+// 	connected, rtt := client.TestPeerConnection(ctx)
+// 	return connected, rtt, nil
+// }
 
-// TestAllPeers tests connectivity to all peers
-func (pm *PeerMonitor) TestAllPeers() map[int]struct {
-	Connected bool
-	RTT       time.Duration
-} {
-	pm.mutex.Lock()
-	peers := make(map[int]*Client, len(pm.monitors))
-	for siteID, client := range pm.monitors {
-		peers[siteID] = client
-	}
-	pm.mutex.Unlock()
+// // TestAllPeers tests connectivity to all peers
+// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
+// 	Connected bool
+// 	RTT       time.Duration
+// } {
+// 	pm.mutex.Lock()
+// 	peers := make(map[int]*Client, len(pm.monitors))
+// 	for siteID, client := range pm.monitors {
+// 		peers[siteID] = client
+// 	}
+// 	pm.mutex.Unlock()
 
-	results := make(map[int]struct {
-		Connected bool
-		RTT       time.Duration
-	})
-	for siteID, client := range peers {
-		ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
-		connected, rtt := client.TestConnection(ctx)
-		cancel()
+// 	results := make(map[int]struct {
+// 		Connected bool
+// 		RTT       time.Duration
+// 	})
+// 	for siteID, client := range peers {
+// 		ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
+// 		connected, rtt := client.TestPeerConnection(ctx)
+// 		cancel()
 
-		results[siteID] = struct {
-			Connected bool
-			RTT       time.Duration
-		}{
-			Connected: connected,
-			RTT:       rtt,
-		}
-	}
+// 		results[siteID] = struct {
+// 			Connected bool
+// 			RTT       time.Duration
+// 		}{
+// 			Connected: connected,
+// 			RTT:       rtt,
+// 		}
+// 	}
 
-	return results
-}
+// 	return results
+// }
 
 // initNetstack initializes the gvisor netstack
 func (pm *PeerMonitor) initNetstack() error {
diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go
index 21f788a..f06759a 100644
--- a/peers/monitor/wgtester.go
+++ b/peers/monitor/wgtester.go
@@ -32,16 +32,19 @@ type Client struct {
 	monitorLock sync.Mutex
 	connLock    sync.Mutex // Protects connection operations
 	shutdownCh  chan struct{}
+	updateCh    chan struct{}
 	packetInterval time.Duration
 	timeout        time.Duration
 	maxAttempts    int
 	dialer         Dialer
 
 	// Exponential backoff fields
-	minInterval          time.Duration // Minimum interval (initial)
-	maxInterval          time.Duration // Maximum interval (cap for backoff)
-	backoffMultiplier    float64       // Multiplier for each stable check
-	stableCountToBackoff int           // Number of stable checks before backing off
+	defaultMinInterval   time.Duration // Default minimum interval (initial)
+	defaultMaxInterval   time.Duration // Default maximum interval (cap for backoff)
+	minInterval          time.Duration // Minimum interval (initial)
+	maxInterval          time.Duration // Maximum interval (cap for backoff)
+	backoffMultiplier    float64       // Multiplier for each stable check
+	stableCountToBackoff int           // Number of stable checks before backing off
 }
 
 // Dialer is a function that creates a connection
@@ -56,43 +59,59 @@ type ConnectionStatus struct {
 // NewClient creates a new connection test client
 func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
 	return &Client{
-		serverAddr:           serverAddr,
-		shutdownCh:           make(chan struct{}),
-		packetInterval:       2 * time.Second,
-		minInterval:          2 * time.Second,
-		maxInterval:          30 * time.Second,
-		backoffMultiplier:    1.5,
-		stableCountToBackoff: 3,                      // After 3 consecutive same-state results, start backing off
-		timeout:              500 * time.Millisecond, // Timeout for individual packets
-		maxAttempts:          3,                      // Default max attempts
-		dialer:               dialer,
+		serverAddr:           serverAddr,
+		shutdownCh:           make(chan struct{}),
+		updateCh:             make(chan struct{}, 1),
+		packetInterval:       2 * time.Second,
+		defaultMinInterval:   2 * time.Second,
+		defaultMaxInterval:   30 * time.Second,
+		minInterval:          2 * time.Second,
+		maxInterval:          30 * time.Second,
+		backoffMultiplier:    1.5,
+		stableCountToBackoff: 3,                      // After 3 consecutive same-state results, start backing off
+		timeout:              500 * time.Millisecond, // Timeout for individual packets
+		maxAttempts:          3,                      // Default max attempts
+		dialer:               dialer,
 	}, nil
 }
 
 // SetPacketInterval changes how frequently packets are sent in monitor mode
-func (c *Client) SetPacketInterval(interval time.Duration) {
-	c.packetInterval = interval
-	c.minInterval = interval
+func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
+	c.monitorLock.Lock()
+	c.packetInterval = minInterval
+	c.minInterval = minInterval
+	c.maxInterval = maxInterval
+	updateCh := c.updateCh
+	monitorRunning := c.monitorRunning
+	c.monitorLock.Unlock()
+
+	// Signal the goroutine to apply the new interval if running
+	if monitorRunning && updateCh != nil {
+		select {
+		case updateCh <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
 }
 
-// SetTimeout changes the timeout for waiting for responses
-func (c *Client) SetTimeout(timeout time.Duration) {
-	c.timeout = timeout
-}
+func (c *Client) ResetPacketInterval() {
+	c.monitorLock.Lock()
+	c.packetInterval = c.defaultMinInterval
+	c.minInterval = c.defaultMinInterval
+	c.maxInterval = c.defaultMaxInterval
+	updateCh := c.updateCh
+	monitorRunning := c.monitorRunning
+	c.monitorLock.Unlock()
 
-// SetMaxAttempts changes the maximum number of attempts for TestConnection
-func (c *Client) SetMaxAttempts(attempts int) {
-	c.maxAttempts = attempts
-}
-
-// SetMaxInterval sets the maximum backoff interval
-func (c *Client) SetMaxInterval(interval time.Duration) {
-	c.maxInterval = interval
-}
-
-// SetBackoffMultiplier sets the multiplier for exponential backoff
-func (c *Client) SetBackoffMultiplier(multiplier float64) {
-	c.backoffMultiplier = multiplier
+	// Signal the goroutine to apply the new interval if running
+	if monitorRunning && updateCh != nil {
+		select {
+		case updateCh <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
 }
 
 // UpdateServerAddr updates the server address and resets the connection
@@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error {
 	return nil
 }
 
-// TestConnection checks if the connection to the server is working
+// TestPeerConnection checks if the connection to the server is working
 // Returns true if connected, false otherwise
-func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
+func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
+	logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
 	if err := c.ensureConnection(); err != nil {
 		logger.Warn("Failed to ensure connection: %v", err)
 		return false, 0
@@ -232,7 +252,7 @@ func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.D
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 	defer cancel()
-	return c.TestConnection(ctx)
+	return c.TestPeerConnection(ctx)
 }
 
 // MonitorCallback is the function type for connection status change callbacks
@@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
 		select {
 		case <-c.shutdownCh:
 			return
+		case <-c.updateCh:
+			// Interval settings changed, reset to minimum
+			c.monitorLock.Lock()
+			currentInterval = c.minInterval
+			c.monitorLock.Unlock()
+
+			// Reset backoff state
+			stableCount = 0
+
+			timer.Reset(currentInterval)
+			logger.Debug("Packet interval updated, reset to %v", currentInterval)
 		case <-timer.C:
 			ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
-			connected, rtt := c.TestConnection(ctx)
+			connected, rtt := c.TestPeerConnection(ctx)
 			cancel()
 
 			statusChanged := connected != lastConnected
@@ -321,4 +352,4 @@ func (c *Client) StopMonitor() {
 
 	close(c.shutdownCh)
 	c.monitorRunning = false
-}
\ No newline at end of file
+}
diff --git a/websocket/client.go b/websocket/client.go
index 34eea35..f040aa4 100644
--- a/websocket/client.go
+++ b/websocket/client.go
@@ -236,7 +236,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
 		Data: data,
 	}
 
-	logger.Debug("Sending message: %s, data: %+v", messageType, data)
+	logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
 
 	c.writeMux.Lock()
 	defer c.writeMux.Unlock()
@@ -258,7 +258,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
 		}
 		err := c.SendMessage(messageType, currentData)
 		if err != nil {
-			logger.Error("Failed to send message: %v", err)
+			logger.Error("websocket: Failed to send message: %v", err)
 		}
 		count++
 	}
@@ -271,7 +271,7 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
 			select {
 			case <-ticker.C:
 				if maxAttempts != -1 && count >= maxAttempts {
-					logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
+					logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
 					return
 				}
 				dataMux.Lock()
@@ -353,7 +353,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 			tlsConfig = &tls.Config{}
 		}
 		tlsConfig.InsecureSkipVerify = true
-		logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
 	}
 
 	tokenData := map[string]interface{}{
@@ -382,7 +382,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 	req.Header.Set("X-CSRF-Token", "x-csrf-protection")
 
 	// print out the request for debugging
-	logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
+	logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
 
 	// Make the request
 	client := &http.Client{}
@@ -399,7 +399,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
-		logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
+		logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
 
 		// Return AuthError for 401/403 status codes
 		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -415,7 +415,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 
 	var tokenResp TokenResponse
 	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
-		logger.Error("Failed to decode token response.")
+		logger.Error("websocket: Failed to decode token response.")
 		return "", nil, fmt.Errorf("failed to decode token response: %w", err)
 	}
 
@@ -427,7 +427,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 		return "", nil, fmt.Errorf("received empty token from server")
 	}
 
-	logger.Debug("Received token: %s", tokenResp.Data.Token)
+	logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
 
 	return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
 }
@@ -442,7 +442,7 @@ func (c *Client) connectWithRetry() {
 		if err != nil {
 			// Check if this is an auth error (401/403)
 			if authErr, ok := err.(*AuthError); ok {
-				logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
+				logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr)
 				// Trigger auth error callback if set (this should terminate the tunnel)
 				if c.onAuthError != nil {
 					c.onAuthError(authErr.StatusCode, authErr.Message)
@@ -452,7 +452,7 @@ func (c *Client) connectWithRetry() {
 				continue
 			}
 			// For other errors (5xx, network issues), continue retrying
-			logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
+			logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
 			time.Sleep(c.reconnectInterval)
 			continue
 		}
@@ -505,7 +505,7 @@ func (c *Client) establishConnection() error {
 
 	// Use new TLS configuration method
 	if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
-		logger.Info("Setting up TLS configuration for WebSocket connection")
+		logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
 		tlsConfig, err := c.setupTLS()
 		if err != nil {
 			return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -519,7 +519,7 @@ func (c *Client) establishConnection() error {
 			dialer.TLSClientConfig = &tls.Config{}
 		}
 		dialer.TLSClientConfig.InsecureSkipVerify = true
-		logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
 	}
 
 	conn, _, err := dialer.Dial(u.String(), nil)
@@ -537,7 +537,7 @@ func (c *Client) establishConnection() error {
 
 	if c.onConnect != nil {
 		if err := c.onConnect(); err != nil {
-			logger.Error("OnConnect callback failed: %v", err)
+			logger.Error("websocket: OnConnect callback failed: %v", err)
 		}
 	}
 
@@ -550,9 +550,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 
 	// Handle new separate certificate configuration
 	if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
-		logger.Info("Loading separate certificate files for mTLS")
-		logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
-		logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
+		logger.Info("websocket: Loading separate certificate files for mTLS")
+		logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
+		logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
 
 		// Load client certificate and key
 		cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -563,7 +563,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 
 		// Load CA certificates for remote validation if specified
 		if len(c.tlsConfig.CAFiles) > 0 {
-			logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
+			logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
 			caCertPool := x509.NewCertPool()
 			for _, caFile := range c.tlsConfig.CAFiles {
 				caCert, err := os.ReadFile(caFile)
@@ -589,13 +589,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 
 	// Fallback to existing PKCS12 implementation for backward compatibility
 	if c.tlsConfig.PKCS12File != "" {
-		logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
 		return c.setupPKCS12TLS()
 	}
 
 	// Legacy fallback using config.TlsClientCert
 	if c.config.TlsClientCert != "" {
-		logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
 		return loadClientCertificate(c.config.TlsClientCert)
 	}
 
@@ -630,7 +630,7 @@ func (c *Client) pingMonitor() {
 				// Expected during shutdown
 				return
 			default:
-				logger.Error("Ping failed: %v", err)
+				logger.Error("websocket: Ping failed: %v", err)
 				c.reconnect()
 				return
 			}
@@ -663,18 +663,23 @@ func (c *Client) readPumpWithDisconnectDetection() {
 		var msg WSMessage
 		err := c.conn.ReadJSON(&msg)
 		if err != nil {
-			// Check if we're shutting down before logging error
+			// Check if we're shutting down or explicitly disconnected before logging error
 			select {
 			case <-c.done:
 				// Expected during shutdown, don't log as error
-				logger.Debug("WebSocket connection closed during shutdown")
+				logger.Debug("websocket: connection closed during shutdown")
 				return
 			default:
+				// Check if explicitly disconnected
+				if c.isDisconnected {
+					logger.Debug("websocket: connection closed: client was explicitly disconnected")
+					return
+				}
 				// Unexpected error during normal operation
 				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
-					logger.Error("WebSocket read error: %v", err)
+					logger.Error("websocket: read error: %v", err)
 				} else {
-					logger.Debug("WebSocket connection closed: %v", err)
+					logger.Debug("websocket: connection closed: %v", err)
 				}
 				return // triggers reconnect via defer
 			}
@@ -696,6 +701,12 @@ func (c *Client) reconnect() {
 		c.conn = nil
 	}
 
+	// Don't reconnect if explicitly disconnected
+	if c.isDisconnected {
+		logger.Debug("websocket: Not reconnecting: client was explicitly disconnected")
+		return
+	}
+
 	// Only reconnect if we're not shutting down
 	select {
 	case <-c.done:
@@ -713,7 +724,7 @@ func (c *Client) setConnected(status bool) {
 
 // LoadClientCertificate Helper method to load client certificates (PKCS12 format)
 func loadClientCertificate(p12Path string) (*tls.Config, error) {
-	logger.Info("Loading tls-client-cert %s", p12Path)
+	logger.Info("websocket: Loading tls-client-cert %s", p12Path)
 	// Read the PKCS12 file
 	p12Data, err := os.ReadFile(p12Path)
 	if err != nil {