diff --git a/go.mod b/go.mod index 4f42df6..0d6bbcb 100644 --- a/go.mod +++ b/go.mod @@ -30,3 +30,5 @@ require ( golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect ) + +replace github.com/fosrl/newt => ../newt diff --git a/go.sum b/go.sum index a543b5a..f6ca61a 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,5 @@ github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4= -github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI= github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8= github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/olm/connect.go b/olm/connect.go index 568c731..a610ea4 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { MiddleDev: o.middleDev, LocalIP: interfaceIP, SharedBind: o.sharedBind, - WSClient: o.olmClient, + WSClient: o.websocket, APIServer: o.apiServer, }) diff --git a/olm/data.go b/olm/data.go index 9c8d33f..fe0b36a 100644 --- a/olm/data.go +++ b/olm/data.go @@ -186,12 +186,12 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ "siteId": handshakeData.SiteId, - }, 1*time.Second) + }, 1*time.Second, 10) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } diff --git a/olm/olm.go b/olm/olm.go index 0810025..8ef6c4a 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -42,9 +42,14 @@ type Olm struct { dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client + websocket *websocket.Client holePunchManager *holepunch.Manager peerManager *peers.PeerManager + // Power mode management + currentPowerMode string + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -58,10 +63,11 @@ type Olm struct { metaMu sync.Mutex stopRegister func() - stopPeerSend func() updateRegister func(newData any) - stopPing chan struct{} + stopServerPing func() + + stopPeerSend func() } // initTunnelInfo creates the shared UDP socket and holepunch manager. 
@@ -134,6 +140,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
 		logFile = file
 	}
 
+	if config.WakeUpDebounce == 0 {
+		config.WakeUpDebounce = 3 * time.Second
+	}
+
 	logger.Debug("Checking permissions for native interface")
 	err := permissions.CheckNativeInterfacePermissions()
 	if err != nil {
@@ -285,9 +295,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 	tunnelCtx, cancel := context.WithCancel(o.olmCtx)
 	o.tunnelCancel = cancel
 
-	// Recreate channels for this tunnel session
-	o.stopPing = make(chan struct{})
-
 	var (
 		id     = config.ID
 		secret = config.Secret
@@ -343,6 +350,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 
 		o.apiServer.SetConnectionStatus(true)
 
+		// restart the periodic server ping if it is not already running
+		if o.stopServerPing == nil {
+			o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
+				"timestamp": time.Now().Unix(),
+				"userToken": olmClient.GetConfig().UserToken,
+			}, 30*time.Second, -1) // -1 means don't stop sending after a max number of attempts
+		}
+
 		if o.connected {
 			logger.Debug("Already connected, skipping registration")
 			return nil
@@ -364,7 +379,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 			"userToken":   userToken,
 			"fingerprint": o.fingerprint,
 			"postures":    o.postures,
-		}, 1*time.Second)
+		}, 1*time.Second, 10)
 
 		// Invoke onRegistered callback if configured
 		if o.olmConfig.OnRegistered != nil {
@@ -372,8 +387,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 			}
 		}
 
-		go o.keepSendingPing(olmClient)
-
 		return nil
 	})
 
@@ -446,7 +459,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 	}
 	defer func() { _ = olmClient.Close() }()
 
-	o.olmClient = olmClient
+	o.websocket = olmClient
 
 	// Wait for context cancellation
 	<-tunnelCtx.Done()
@@ -465,9 +478,9 @@ func (o *Olm) Close() {
 		o.holePunchManager = nil
 	}
 
-	if o.stopPing != nil {
-		close(o.stopPing)
-		o.stopPing = nil
+	if o.stopServerPing != nil {
+		o.stopServerPing()
+		o.stopServerPing = nil
 	}
 
 	if o.stopRegister != nil {
@@ -545,9 +558,9 @@ func (o *Olm) StopTunnel() error {
 	}
 
 	// Close the websocket connection
-	if o.olmClient != nil {
-		_ = o.olmClient.Close()
-		o.olmClient = nil
+	if o.websocket != nil {
+		_ = o.websocket.Close()
+		o.websocket = nil
 	}
 
 	o.Close()
@@ -625,3 +638,157 @@ func (o *Olm) SetPostures(data map[string]any) {
 	o.postures = data
 }
+
+// SetPowerMode switches between normal and low power modes
+// In low power mode: the websocket is disconnected (suspending pings) and monitoring intervals are set to 10 minutes
+// In normal power mode: the websocket is reconnected (resuming pings) and monitoring intervals are restored
+// Wake-up is debounced (WakeUpDebounce, 3 seconds by default) to prevent rapid flip-flopping; sleep is immediate
+func (o *Olm) SetPowerMode(mode string) error {
+	// Validate mode
+	if mode != "normal" && mode != "low" {
+		return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode)
+	}
+
+	o.powerModeMu.Lock()
+	defer o.powerModeMu.Unlock()
+
+	// If already in the requested mode, return early
+	if o.currentPowerMode == mode {
+		// Cancel any pending wake-up timer if we're already in normal mode
+		if mode == "normal" && o.wakeUpTimer != nil {
+			o.wakeUpTimer.Stop()
+			o.wakeUpTimer = nil
+		}
+		logger.Debug("Already in %s power mode", mode)
+		return nil
+	}
+
+	if mode == "low" {
+		// Low Power Mode: Cancel any pending wake-up and immediately go to sleep
+
+		// Cancel pending wake-up timer if any
+		if o.wakeUpTimer != nil {
+			logger.Debug("Cancelling pending wake-up timer")
+			o.wakeUpTimer.Stop()
+			o.wakeUpTimer = nil
+		}
+
+		logger.Info("Switching to low power mode")
+
+		if o.websocket != 
nil { + logger.Info("Disconnecting websocket for low power mode") + if err := o.websocket.Disconnect(); err != nil { + logger.Error("Error disconnecting websocket: %v", err) + } + } + + lowPowerInterval := 10 * time.Minute + + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval) + peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval) + logger.Info("Set monitoring intervals to 10 minutes for low power mode") + } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable + } + + if o.holePunchManager != nil { + o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval) + } + + o.currentPowerMode = "low" + logger.Info("Switched to low power mode") + + } else { + // Normal Power Mode: Start debounce timer before actually waking up + + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil + } + + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) + + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return + } + + logger.Info("Debounce complete, switching to normal power mode") + + logger.Info("Reconnecting websocket for normal power mode") + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetPeerHolepunchInterval() + peerMonitor.ResetPeerInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + if o.holePunchManager != nil { + o.holePunchManager.ResetServerHolepunchInterval() + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) + } + + return nil +} + +func (o *Olm) AddDevice(fd uint32) error { + if o.middleDev == nil { + return fmt.Errorf("middle device is not initialized") + } + + if o.tunnelConfig.MTU == 0 { + return fmt.Errorf("tunnel MTU is not set") + } + + tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU) + if err != nil { + return fmt.Errorf("failed to create TUN device from fd: %v", err) + } + + // Update interface name if available + if realInterfaceName, err2 := tdev.Name(); err2 == nil { + o.tunnelConfig.InterfaceName = realInterfaceName + } + + // Replace the existing TUN device in the middle device with the new one + o.middleDev.AddDevice(tdev) + + logger.Info("Added device from file descriptor %d", fd) + + return nil +} + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/peer.go b/olm/peer.go index 8acec42..9bc842e 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) { if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint { logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId) _ 
= o.holePunchManager.TriggerHolePunch() - o.holePunchManager.ResetInterval() + o.holePunchManager.ResetServerHolepunchInterval() } logger.Info("Successfully updated peer for site %d", updateData.SiteId) diff --git a/olm/ping.go b/olm/ping.go deleted file mode 100644 index 460fc38..0000000 --- a/olm/ping.go +++ /dev/null @@ -1,57 +0,0 @@ -package olm - -import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/olm/websocket" -) - -func (o *Olm) sendPing(olm *websocket.Client) error { - err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - "fingerprint": o.fingerprint, - "postures": o.postures, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func (o *Olm) keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := o.sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-o.stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := o.sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() -} diff --git a/olm/types.go b/olm/types.go index 28e2260..bdcea03 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,8 @@ type OlmConfig struct { SocketPath string Version string Agent string + + WakeUpDebounce time.Duration // Debugging PprofAddr string // Address to serve pprof on (e.g., "localhost:6060") diff --git a/peers/manager.go b/peers/manager.go index af781e5..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) { return peer, ok } +// GetPeerMonitor returns the internal peer monitor instance +func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor { + pm.mu.RLock() + defer pm.mu.RUnlock() + return pm.peerMonitor +} + func (pm *PeerManager) GetAllPeers() []SiteConfig { pm.mu.RLock() defer pm.mu.RUnlock() @@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. 
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
+	pm.mu.Lock()
+	defer pm.mu.Unlock()
+
+	pm.PersistentKeepalive = interval
+
+	errors := make(map[int]error)
+
+	for siteId, peer := range pm.peers {
+		err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
+		if err != nil {
+			errors[siteId] = err
+		}
+	}
+
+	if len(errors) == 0 {
+		return nil
+	}
+	return errors
+}
+
 func (pm *PeerManager) RemovePeer(siteId int) error {
 	pm.mu.Lock()
 	defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
 			ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
 			wgConfig := promotedPeer
 			wgConfig.AllowedIps = ownedIPs
-			if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
+			if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
 				logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
 			}
 		}
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
 	wgConfig := siteConfig
 	wgConfig.AllowedIps = ownedIPs
 
-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
 		return err
 	}
 
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
 			promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
 			promotedWgConfig := promotedPeer
 			promotedWgConfig.AllowedIps = promotedOwnedIPs
-			if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
+			if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
 				logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
 			}
 		}
diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go
index 45dd090..387b82f 100644
--- a/peers/monitor/monitor.go
+++ b/peers/monitor/monitor.go
@@ -31,8 +31,7 @@ type PeerMonitor struct {
 	monitors    map[int]*Client
 	mutex       sync.Mutex
 	running     bool
-	interval    time.Duration
-	timeout     time.Duration
+	timeout     time.Duration
 	maxAttempts int
 	wsClient    *websocket.Client
 
@@ -50,11 +49,11 @@ type PeerMonitor struct {
 	// Holepunch testing fields
 	sharedBind         *bind.SharedBind
 	holepunchTester    *holepunch.HolepunchTester
-	holepunchInterval  time.Duration
 	holepunchTimeout   time.Duration
 	holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
 	holepunchStatus    map[int]bool   // siteID -> connected status
-	holepunchStopChan  chan struct{}
+	holepunchStopChan   chan struct{}
+	holepunchUpdateChan chan struct{}
 
 	// Relay tracking fields
 	relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -62,11 +61,13 @@ type PeerMonitor struct {
 	holepunchFailures map[int]int // siteID -> consecutive failure count
 
 	// Exponential backoff fields for holepunch monitor
-	holepunchMinInterval       time.Duration // Minimum interval (initial)
-	holepunchMaxInterval       time.Duration // Maximum interval (cap for backoff)
-	holepunchBackoffMultiplier float64       // Multiplier for each stable check
-	holepunchStableCount       map[int]int   // siteID -> consecutive stable status count
-	holepunchCurrentInterval   time.Duration // Current interval with backoff applied
+	defaultHolepunchMinInterval 
time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -85,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), - interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, wsClient: wsClient, @@ -95,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -109,11 +108,14 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe apiServer: apiServer, wgConnectionStatus: make(map[int]bool), // Exponential backoff settings for holepunch monitor - holepunchMinInterval: 2 * time.Second, - holepunchMaxInterval: 30 * time.Second, - holepunchBackoffMultiplier: 1.5, - holepunchStableCount: make(map[int]int), - holepunchCurrentInterval: 2 * time.Second, + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, + holepunchUpdateChan: make(chan struct{}, 1), } if err := pm.initNetstack(); err != nil { @@ -129,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe } // SetInterval changes how frequently peers are checked -func (pm *PeerMonitor) SetInterval(interval time.Duration) { +func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.interval = interval - // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetPacketInterval(interval) + client.SetPacketInterval(minInterval, maxInterval) } + + logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval) } -// SetTimeout changes the timeout for waiting for responses -func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { +func (pm *PeerMonitor) ResetPeerInterval() { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.timeout = timeout - - // Update timeout for all existing monitors + // Update interval for all existing monitors for _, client := range pm.monitors { - client.SetTimeout(timeout) + client.ResetPacketInterval() } } -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (pm *PeerMonitor) SetMaxAttempts(attempts int) { +// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) { + pm.mutex.Lock() + pm.holepunchMinInterval = minInterval + pm.holepunchMaxInterval = 
maxInterval + // Reset current interval to the new minimum + pm.holepunchCurrentInterval = minInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } +} + +// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring +func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) { pm.mutex.Lock() defer pm.mutex.Unlock() - pm.maxAttempts = attempts + return pm.holepunchMinInterval, pm.holepunchMaxInterval +} - // Update max attempts for all existing monitors - for _, client := range pm.monitors { - client.SetMaxAttempts(attempts) +func (pm *PeerMonitor) ResetPeerHolepunchInterval() { + pm.mutex.Lock() + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval + updateChan := pm.holepunchUpdateChan + pm.mutex.Unlock() + + logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval) + + // Signal the goroutine to apply the new interval if running + if updateChan != nil { + select { + case updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } } } @@ -182,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st return err } - client.SetPacketInterval(pm.interval) - client.SetTimeout(pm.timeout) - client.SetMaxAttempts(pm.maxAttempts) - client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable - pm.monitors[siteID] = client pm.holepunchEndpoints[siteID] = holepunchEndpoint @@ -497,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() { select { case <-pm.holepunchStopChan: return + case <-pm.holepunchUpdateChan: + // Interval settings changed, reset to minimum + pm.mutex.Lock() + pm.holepunchCurrentInterval = pm.holepunchMinInterval + currentInterval := pm.holepunchCurrentInterval + pm.mutex.Unlock() + + timer.Reset(currentInterval) + logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval) case <-timer.C: anyStatusChanged := pm.checkHolepunchEndpoints() @@ -540,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool { anyStatusChanged := false for siteID, endpoint := range endpoints { - // logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint) + logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint) result := pm.holepunchTester.TestEndpoint(endpoint, timeout) pm.mutex.Lock() @@ -689,55 +729,55 @@ func (pm *PeerMonitor) Close() { logger.Debug("PeerMonitor: Cleanup complete") } -// TestPeer tests connectivity to a specific peer -func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { - pm.mutex.Lock() - client, exists := pm.monitors[siteID] - pm.mutex.Unlock() +// // TestPeer tests connectivity to a specific peer +// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) { +// pm.mutex.Lock() +// client, exists := pm.monitors[siteID] +// pm.mutex.Unlock() - if !exists { - return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) - } +// if !exists { +// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID) +// } - ctx, 
cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - defer cancel() +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// defer cancel() - connected, rtt := client.TestConnection(ctx) - return connected, rtt, nil -} +// connected, rtt := client.TestPeerConnection(ctx) +// return connected, rtt, nil +// } -// TestAllPeers tests connectivity to all peers -func (pm *PeerMonitor) TestAllPeers() map[int]struct { - Connected bool - RTT time.Duration -} { - pm.mutex.Lock() - peers := make(map[int]*Client, len(pm.monitors)) - for siteID, client := range pm.monitors { - peers[siteID] = client - } - pm.mutex.Unlock() +// // TestAllPeers tests connectivity to all peers +// func (pm *PeerMonitor) TestAllPeers() map[int]struct { +// Connected bool +// RTT time.Duration +// } { +// pm.mutex.Lock() +// peers := make(map[int]*Client, len(pm.monitors)) +// for siteID, client := range pm.monitors { +// peers[siteID] = client +// } +// pm.mutex.Unlock() - results := make(map[int]struct { - Connected bool - RTT time.Duration - }) - for siteID, client := range peers { - ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) - connected, rtt := client.TestConnection(ctx) - cancel() +// results := make(map[int]struct { +// Connected bool +// RTT time.Duration +// }) +// for siteID, client := range peers { +// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts)) +// connected, rtt := client.TestPeerConnection(ctx) +// cancel() - results[siteID] = struct { - Connected bool - RTT time.Duration - }{ - Connected: connected, - RTT: rtt, - } - } +// results[siteID] = struct { +// Connected bool +// RTT time.Duration +// }{ +// Connected: connected, +// RTT: rtt, +// } +// } - return results -} +// return results +// } // initNetstack initializes the gvisor netstack func (pm *PeerMonitor) initNetstack() error { diff --git a/peers/monitor/wgtester.go b/peers/monitor/wgtester.go index 21f788a..f06759a 100644 --- a/peers/monitor/wgtester.go +++ b/peers/monitor/wgtester.go @@ -32,16 +32,19 @@ type Client struct { monitorLock sync.Mutex connLock sync.Mutex // Protects connection operations shutdownCh chan struct{} + updateCh chan struct{} packetInterval time.Duration timeout time.Duration maxAttempts int dialer Dialer // Exponential backoff fields - minInterval time.Duration // Minimum interval (initial) - maxInterval time.Duration // Maximum interval (cap for backoff) - backoffMultiplier float64 // Multiplier for each stable check - stableCountToBackoff int // Number of stable checks before backing off + defaultMinInterval time.Duration // Default minimum interval (initial) + defaultMaxInterval time.Duration // Default maximum interval (cap for backoff) + minInterval time.Duration // Minimum interval (initial) + maxInterval time.Duration // Maximum interval (cap for backoff) + backoffMultiplier float64 // Multiplier for each stable check + stableCountToBackoff int // Number of stable checks before backing off } // Dialer is a function that creates a connection @@ -56,43 +59,59 @@ type ConnectionStatus struct { // NewClient creates a new connection test client func NewClient(serverAddr string, dialer Dialer) (*Client, error) { return &Client{ - serverAddr: serverAddr, - shutdownCh: make(chan struct{}), - packetInterval: 2 * time.Second, - minInterval: 2 * time.Second, - maxInterval: 30 * time.Second, - backoffMultiplier: 1.5, - 
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off - timeout: 500 * time.Millisecond, // Timeout for individual packets - maxAttempts: 3, // Default max attempts - dialer: dialer, + serverAddr: serverAddr, + shutdownCh: make(chan struct{}), + updateCh: make(chan struct{}, 1), + packetInterval: 2 * time.Second, + defaultMinInterval: 2 * time.Second, + defaultMaxInterval: 30 * time.Second, + minInterval: 2 * time.Second, + maxInterval: 30 * time.Second, + backoffMultiplier: 1.5, + stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off + timeout: 500 * time.Millisecond, // Timeout for individual packets + maxAttempts: 3, // Default max attempts + dialer: dialer, }, nil } // SetPacketInterval changes how frequently packets are sent in monitor mode -func (c *Client) SetPacketInterval(interval time.Duration) { - c.packetInterval = interval - c.minInterval = interval +func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) { + c.monitorLock.Lock() + c.packetInterval = minInterval + c.minInterval = minInterval + c.maxInterval = maxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() + + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } -// SetTimeout changes the timeout for waiting for responses -func (c *Client) SetTimeout(timeout time.Duration) { - c.timeout = timeout -} +func (c *Client) ResetPacketInterval() { + c.monitorLock.Lock() + c.packetInterval = c.defaultMinInterval + c.minInterval = c.defaultMinInterval + c.maxInterval = c.defaultMaxInterval + updateCh := c.updateCh + monitorRunning := c.monitorRunning + c.monitorLock.Unlock() -// SetMaxAttempts changes the maximum number of attempts for TestConnection -func (c *Client) SetMaxAttempts(attempts int) { - c.maxAttempts = attempts -} - -// SetMaxInterval sets the maximum backoff interval -func (c *Client) SetMaxInterval(interval time.Duration) { - c.maxInterval = interval -} - -// SetBackoffMultiplier sets the multiplier for exponential backoff -func (c *Client) SetBackoffMultiplier(multiplier float64) { - c.backoffMultiplier = multiplier + // Signal the goroutine to apply the new interval if running + if monitorRunning && updateCh != nil { + select { + case updateCh <- struct{}{}: + default: + // Channel full or closed, skip + } + } } // UpdateServerAddr updates the server address and resets the connection @@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error { return nil } -// TestConnection checks if the connection to the server is working +// TestPeerConnection checks if the connection to the server is working // Returns true if connected, false otherwise -func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { +func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) { + logger.Debug("wgtester: testing connection to peer %s", c.serverAddr) if err := c.ensureConnection(); err != nil { logger.Warn("Failed to ensure connection: %v", err) return false, 0 @@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) { func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - return c.TestConnection(ctx) + return c.TestPeerConnection(ctx) } // MonitorCallback is the 
function type for connection status change callbacks @@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error { select { case <-c.shutdownCh: return + case <-c.updateCh: + // Interval settings changed, reset to minimum + c.monitorLock.Lock() + currentInterval = c.minInterval + c.monitorLock.Unlock() + + // Reset backoff state + stableCount = 0 + + timer.Reset(currentInterval) + logger.Debug("Packet interval updated, reset to %v", currentInterval) case <-timer.C: ctx, cancel := context.WithTimeout(context.Background(), c.timeout) - connected, rtt := c.TestConnection(ctx) + connected, rtt := c.TestPeerConnection(ctx) cancel() statusChanged := connected != lastConnected @@ -321,4 +352,4 @@ func (c *Client) StopMonitor() { close(c.shutdownCh) c.monitorRunning = false -} \ No newline at end of file +} diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint diff --git a/websocket/client.go b/websocket/client.go index 1c5afaf..b50cf31 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -77,6 +77,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -87,6 +88,10 @@ type Client struct { clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig configNeedsSave bool // Flag to track if config needs to be saved + token string // Cached authentication token + exitNodes []ExitNode // Cached exit nodes from token response + tokenMux sync.RWMutex // Protects token and exitNodes + forceNewToken bool 
// Flag to force fetching a new token on next connection } type ClientOption func(*Client) @@ -173,6 +178,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -205,9 +213,25 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. +func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -216,14 +240,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { Data: data, } - logger.Debug("Sending message: %s, data: %+v", messageType, data) + logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data) c.writeMux.Lock() defer c.writeMux.Unlock() return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -231,30 +255,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("websocket: Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { - logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + if maxAttempts != -1 && count >= maxAttempts { + logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -277,6 +303,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { @@ -323,7 +357,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { tlsConfig = &tls.Config{} } tlsConfig.InsecureSkipVerify = true - logger.Debug("TLS certificate verification disabled 
via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } tokenData := map[string]interface{}{ @@ -352,7 +386,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { req.Header.Set("X-CSRF-Token", "x-csrf-protection") // print out the request for debugging - logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) + logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData)) // Make the request client := &http.Client{} @@ -369,7 +403,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) - logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { @@ -385,7 +419,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { - logger.Error("Failed to decode token response.") + logger.Error("websocket: Failed to decode token response.") return "", nil, fmt.Errorf("failed to decode token response: %w", err) } @@ -397,7 +431,7 @@ func (c *Client) getToken() (string, []ExitNode, error) { return "", nil, fmt.Errorf("received empty token from server") } - logger.Debug("Received token: %s", tokenResp.Data.Token) + logger.Debug("websocket: Received token: %s", tokenResp.Data.Token) return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } @@ -412,7 +446,7 @@ func (c *Client) connectWithRetry() { if err != nil { // Check if this is an auth error (401/403) if authErr, ok := err.(*AuthError); ok { - logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr) + logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr) // Trigger auth error callback if set (this should terminate the tunnel) if c.onAuthError != nil { c.onAuthError(authErr.StatusCode, authErr.Message) @@ -422,7 +456,7 @@ func (c *Client) connectWithRetry() { continue } // For other errors (5xx, network issues), continue retrying - logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + logger.Error("websocket: Failed to connect: %v. 
Retrying in %v...", err, c.reconnectInterval) time.Sleep(c.reconnectInterval) continue } @@ -432,15 +466,25 @@ func (c *Client) connectWithRetry() { } func (c *Client) establishConnection() error { - // Get token for authentication - token, exitNodes, err := c.getToken() - if err != nil { - return fmt.Errorf("failed to get token: %w", err) - } - - if c.onTokenUpdate != nil { - c.onTokenUpdate(token, exitNodes) + // Get token for authentication - reuse cached token unless forced to get new one + c.tokenMux.Lock() + needNewToken := c.token == "" || c.forceNewToken + if needNewToken { + token, exitNodes, err := c.getToken() + if err != nil { + c.tokenMux.Unlock() + return fmt.Errorf("failed to get token: %w", err) + } + c.token = token + c.exitNodes = exitNodes + c.forceNewToken = false + + if c.onTokenUpdate != nil { + c.onTokenUpdate(token, exitNodes) + } } + token := c.token + c.tokenMux.Unlock() // Parse the base URL to determine protocol and hostname baseURL, err := url.Parse(c.baseURL) @@ -475,7 +519,7 @@ func (c *Client) establishConnection() error { // Use new TLS configuration method if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { - logger.Info("Setting up TLS configuration for WebSocket connection") + logger.Info("websocket: Setting up TLS configuration for WebSocket connection") tlsConfig, err := c.setupTLS() if err != nil { return fmt.Errorf("failed to setup TLS configuration: %w", err) @@ -489,11 +533,23 @@ func (c *Client) establishConnection() error { dialer.TLSClientConfig = &tls.Config{} } dialer.TLSClientConfig.InsecureSkipVerify = true - logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - conn, _, err := dialer.Dial(u.String(), nil) + conn, resp, err := dialer.Dial(u.String(), nil) if err != nil { + // Check if this is an unauthorized error (401) + if resp != nil && resp.StatusCode == http.StatusUnauthorized { + logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized") + // Force getting a new token on next reconnect attempt + c.tokenMux.Lock() + c.forceNewToken = true + c.tokenMux.Unlock() + return &AuthError{ + StatusCode: http.StatusUnauthorized, + Message: "WebSocket connection unauthorized", + } + } return fmt.Errorf("failed to connect to WebSocket: %w", err) } @@ -507,7 +563,7 @@ func (c *Client) establishConnection() error { if c.onConnect != nil { if err := c.onConnect(); err != nil { - logger.Error("OnConnect callback failed: %v", err) + logger.Error("websocket: OnConnect callback failed: %v", err) } } @@ -520,9 +576,9 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Handle new separate certificate configuration if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { - logger.Info("Loading separate certificate files for mTLS") - logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) - logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + logger.Info("websocket: Loading separate certificate files for mTLS") + logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile) // Load client certificate and key cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) @@ -533,7 +589,7 @@ func (c *Client) setupTLS() (*tls.Config, error) { // Load CA 
certificates for remote validation if specified
 		if len(c.tlsConfig.CAFiles) > 0 {
-			logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
+			logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
 			caCertPool := x509.NewCertPool()
 			for _, caFile := range c.tlsConfig.CAFiles {
 				caCert, err := os.ReadFile(caFile)
@@ -559,13 +615,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 
 	// Fallback to existing PKCS12 implementation for backward compatibility
 	if c.tlsConfig.PKCS12File != "" {
-		logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
 		return c.setupPKCS12TLS()
 	}
 
 	// Legacy fallback using config.TlsClientCert
 	if c.config.TlsClientCert != "" {
-		logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
 		return loadClientCertificate(c.config.TlsClientCert)
 	}
 
@@ -587,7 +643,7 @@ func (c *Client) pingMonitor() {
 		case <-c.done:
 			return
 		case <-ticker.C:
-			if c.conn == nil {
+			if c.isDisconnected || c.conn == nil {
 				return
 			}
 			c.writeMux.Lock()
@@ -600,7 +656,7 @@ func (c *Client) pingMonitor() {
 					// Expected during shutdown
 					return
 				default:
-					logger.Error("Ping failed: %v", err)
+					logger.Error("websocket: Ping failed: %v", err)
 					c.reconnect()
 					return
 				}
@@ -633,18 +689,24 @@ func (c *Client) readPumpWithDisconnectDetection() {
 		var msg WSMessage
 		err := c.conn.ReadJSON(&msg)
 		if err != nil {
-			// Check if we're shutting down before logging error
+			// Check if we're shutting down or explicitly disconnected before logging error
 			select {
 			case <-c.done:
 				// Expected during shutdown, don't log as error
-				logger.Debug("WebSocket connection closed during shutdown")
+				logger.Debug("websocket: connection closed during shutdown")
 				return
 			default:
+				// Check if explicitly disconnected
+				if c.isDisconnected {
+					logger.Debug("websocket: connection closed: client was explicitly disconnected")
+					return
+				}
+
 				// Unexpected error during normal operation
 				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
-					logger.Error("WebSocket read error: %v", err)
+					logger.Error("websocket: read error: %v", err)
 				} else {
-					logger.Debug("WebSocket connection closed: %v", err)
+					logger.Debug("websocket: connection closed: %v", err)
 				}
 				return // triggers reconnect via defer
 			}
@@ -666,6 +728,12 @@ func (c *Client) reconnect() {
 		c.conn = nil
 	}
 
+	// Don't reconnect if explicitly disconnected
+	if c.isDisconnected {
+		logger.Debug("websocket: Not reconnecting: client was explicitly disconnected")
+		return
+	}
+
 	// Only reconnect if we're not shutting down
 	select {
 	case <-c.done:
@@ -683,7 +751,7 @@ func (c *Client) setConnected(status bool) {
 
 // LoadClientCertificate Helper method to load client certificates (PKCS12 format)
 func loadClientCertificate(p12Path string) (*tls.Config, error) {
-	logger.Info("Loading tls-client-cert %s", p12Path)
+	logger.Info("websocket: Loading tls-client-cert %s", p12Path)
 	// Read the PKCS12 file
 	p12Data, err := os.ReadFile(p12Path)
 	if err != nil {