Merge branch 'power-state' into dev

Former-commit-id: e2a071e6dc
Owen
2026-01-15 16:39:41 -08:00
13 changed files with 553 additions and 252 deletions

go.mod

@@ -30,3 +30,5 @@ require (
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
)
replace github.com/fosrl/newt => ../newt

go.sum

@@ -1,7 +1,5 @@
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4=
github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=


@@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.olmClient,
WSClient: o.websocket,
APIServer: o.apiServer,
})


@@ -186,12 +186,12 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
}
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetInterval() // resume sending immediately so the endpoint is re-registered with the cloud
o.holePunchManager.ResetServerHolepunchInterval() // resume sending immediately so the endpoint is re-registered with the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second)
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}


@@ -42,9 +42,14 @@ type Olm struct {
dnsProxy *dns.DNSProxy
apiServer *api.API
olmClient *websocket.Client
websocket *websocket.Client
holePunchManager *holepunch.Manager
peerManager *peers.PeerManager
// Power mode management
currentPowerMode string
powerModeMu sync.Mutex
wakeUpTimer *time.Timer
wakeUpDebounce time.Duration
olmCtx context.Context
tunnelCancel context.CancelFunc
@@ -58,10 +63,11 @@ type Olm struct {
metaMu sync.Mutex
stopRegister func()
stopPeerSend func()
updateRegister func(newData any)
stopPing chan struct{}
stopServerPing func()
stopPeerSend func()
}
// initTunnelInfo creates the shared UDP socket and holepunch manager.
@@ -134,6 +140,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logFile = file
}
if config.WakeUpDebounce == 0 {
config.WakeUpDebounce = 3 * time.Second
}
logger.Debug("Checking permissions for native interface")
err := permissions.CheckNativeInterfacePermissions()
if err != nil {
@@ -285,9 +295,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
o.tunnelCancel = cancel
// Recreate channels for this tunnel session
o.stopPing = make(chan struct{})
var (
id = config.ID
secret = config.Secret
@@ -343,6 +350,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
o.apiServer.SetConnectionStatus(true)
// restart the server ping if it is not already running
if o.stopServerPing == nil {
o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
"timestamp": time.Now().Unix(),
"userToken": olmClient.GetConfig().UserToken,
}, 30*time.Second, -1) // -1 means never give up after a max number of attempts
}
if o.connected {
logger.Debug("Already connected, skipping registration")
return nil
@@ -364,7 +379,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
}, 1*time.Second)
}, 1*time.Second, 10)
// Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil {
@@ -372,8 +387,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}
}
go o.keepSendingPing(olmClient)
return nil
})
@@ -446,7 +459,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}
defer func() { _ = olmClient.Close() }()
o.olmClient = olmClient
o.websocket = olmClient
// Wait for context cancellation
<-tunnelCtx.Done()
@@ -465,9 +478,9 @@ func (o *Olm) Close() {
o.holePunchManager = nil
}
if o.stopPing != nil {
close(o.stopPing)
o.stopPing = nil
if o.stopServerPing != nil {
o.stopServerPing()
o.stopServerPing = nil
}
if o.stopRegister != nil {
@@ -545,9 +558,9 @@ func (o *Olm) StopTunnel() error {
}
// Close the websocket connection
if o.olmClient != nil {
_ = o.olmClient.Close()
o.olmClient = nil
if o.websocket != nil {
_ = o.websocket.Close()
o.websocket = nil
}
o.Close()
@@ -625,3 +638,157 @@ func (o *Olm) SetPostures(data map[string]any) {
o.postures = data
}
// SetPowerMode switches between normal and low power modes.
// In low power mode the websocket is closed (stopping pings) and monitoring intervals are stretched to 10 minutes.
// In normal power mode the websocket is reconnected (restarting pings) and monitoring intervals are restored.
// Wake-up is debounced (WakeUpDebounce, default 3 seconds) to prevent rapid flip-flopping; sleep takes effect immediately.
func (o *Olm) SetPowerMode(mode string) error {
// Validate mode
if mode != "normal" && mode != "low" {
return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode)
}
o.powerModeMu.Lock()
defer o.powerModeMu.Unlock()
// If already in the requested mode, return early
if o.currentPowerMode == mode {
// Cancel any pending wake-up timer if we're already in normal mode
if mode == "normal" && o.wakeUpTimer != nil {
o.wakeUpTimer.Stop()
o.wakeUpTimer = nil
}
logger.Debug("Already in %s power mode", mode)
return nil
}
if mode == "low" {
// Low Power Mode: Cancel any pending wake-up and immediately go to sleep
// Cancel pending wake-up timer if any
if o.wakeUpTimer != nil {
logger.Debug("Cancelling pending wake-up timer")
o.wakeUpTimer.Stop()
o.wakeUpTimer = nil
}
logger.Info("Switching to low power mode")
if o.websocket != nil {
logger.Info("Disconnecting websocket for low power mode")
if err := o.websocket.Disconnect(); err != nil {
logger.Error("Error disconnecting websocket: %v", err)
}
}
lowPowerInterval := 10 * time.Minute
if o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil {
peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
logger.Info("Set monitoring intervals to 10 minutes for low power mode")
}
o.peerManager.UpdateAllPeersPersistentKeepalive(0) // 0 disables persistent keepalive
}
if o.holePunchManager != nil {
o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
}
o.currentPowerMode = "low"
logger.Info("Switched to low power mode")
} else {
// Normal Power Mode: Start debounce timer before actually waking up
// If there's already a pending wake-up timer, don't start another
if o.wakeUpTimer != nil {
logger.Debug("Wake-up already pending, ignoring duplicate request")
return nil
}
logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce)
o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() {
o.powerModeMu.Lock()
defer o.powerModeMu.Unlock()
// Clear the timer reference
o.wakeUpTimer = nil
// Double-check we're still in low power mode (could have changed)
if o.currentPowerMode == "normal" {
logger.Debug("Already in normal mode after debounce, skipping wake-up")
return
}
logger.Info("Debounce complete, switching to normal power mode")
logger.Info("Reconnecting websocket for normal power mode")
if o.websocket != nil {
if err := o.websocket.Connect(); err != nil {
logger.Error("Failed to reconnect websocket: %v", err)
return
}
}
// Restore monitoring intervals and keepalives
if o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil {
peerMonitor.ResetPeerHolepunchInterval()
peerMonitor.ResetPeerInterval()
}
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
}
if o.holePunchManager != nil {
o.holePunchManager.ResetServerHolepunchInterval()
}
o.currentPowerMode = "normal"
logger.Info("Switched to normal power mode")
})
}
return nil
}
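For context, a minimal sketch of how a platform layer might drive SetPowerMode from OS power notifications. The events channel and its "suspend"/"resume" values are hypothetical stand-ins for a real notification source (e.g. D-Bus on Linux or IOKit on macOS); only SetPowerMode itself comes from this commit.

func watchPowerEvents(o *Olm, events <-chan string) {
	for ev := range events {
		switch ev {
		case "suspend":
			// Sleep takes effect immediately: the websocket is disconnected
			// and monitoring intervals stretch to 10 minutes.
			if err := o.SetPowerMode("low"); err != nil {
				logger.Error("power: %v", err)
			}
		case "resume":
			// Wake-up is debounced by WakeUpDebounce (default 3s), so a
			// rapid suspend/resume flip-flop never churns the connection.
			if err := o.SetPowerMode("normal"); err != nil {
				logger.Error("power: %v", err)
			}
		}
	}
}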
func (o *Olm) AddDevice(fd uint32) error {
if o.middleDev == nil {
return fmt.Errorf("middle device is not initialized")
}
if o.tunnelConfig.MTU == 0 {
return fmt.Errorf("tunnel MTU is not set")
}
tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU)
if err != nil {
return fmt.Errorf("failed to create TUN device from fd: %v", err)
}
// Update interface name if available
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
o.tunnelConfig.InterfaceName = realInterfaceName
}
// Replace the existing TUN device in the middle device with the new one
o.middleDev.AddDevice(tdev)
logger.Info("Added device from file descriptor %d", fd)
return nil
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}


@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
_ = o.holePunchManager.TriggerHolePunch()
o.holePunchManager.ResetInterval()
o.holePunchManager.ResetServerHolepunchInterval()
}
logger.Info("Successfully updated peer for site %d", updateData.SiteId)


@@ -1,57 +0,0 @@
package olm
import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
)
func (o *Olm) sendPing(olm *websocket.Client) error {
err := olm.SendMessage("olm/ping", map[string]any{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
"fingerprint": o.fingerprint,
"postures": o.postures,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
}
logger.Debug("Sent ping message")
return nil
}
func (o *Olm) keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := o.sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Set up ticker for 30-second intervals
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-o.stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := o.sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
}
}
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}


@@ -24,6 +24,8 @@ type OlmConfig struct {
Version string
Agent string
WakeUpDebounce time.Duration
// Debugging
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
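A sketch of overriding the new debounce at construction time; the other OlmConfig fields are elided here, and a zero WakeUpDebounce falls back to the 3-second default applied in Init.

cfg := OlmConfig{
	// ... id, secret, endpoint, etc. elided ...
	WakeUpDebounce: 5 * time.Second, // require 5s of sustained wake before reconnecting
}
olm, err := Init(ctx, cfg)
if err != nil {
	return err
}
_ = olm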


@@ -50,6 +50,8 @@ type PeerManager struct {
// key is the CIDR string, value is a set of siteIds that want this IP
allowedIPClaims map[string]map[int]bool
APIServer *api.API
PersistentKeepalive int
}
// NewPeerManager creates a new PeerManager with an internal PeerMonitor
@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
return peer, ok
}
// GetPeerMonitor returns the internal peer monitor instance
func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
pm.mu.RLock()
defer pm.mu.RUnlock()
return pm.peerMonitor
}
func (pm *PeerManager) GetAllPeers() []SiteConfig {
pm.mu.RLock()
defer pm.mu.RUnlock()
@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
return nil
}
// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
// without recreating them. Returns a map of siteId to error for any peers that failed to update.
func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
pm.mu.RLock()
defer pm.mu.RUnlock()
pm.PersistentKeepalive = interval
errors := make(map[int]error)
for siteId, peer := range pm.peers {
err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
if err != nil {
errors[siteId] = err
}
}
if len(errors) == 0 {
return nil
}
return errors
}
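A sketch of consuming the per-peer error map. An interval of 0 disables persistent keepalive entirely (standard WireGuard semantics), which is what SetPowerMode relies on in low power mode.

if errs := pm.UpdateAllPeersPersistentKeepalive(0); errs != nil {
	for siteId, err := range errs {
		logger.Warn("failed to update keepalive for site %d: %v", siteId, err)
	}
}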
func (pm *PeerManager) RemovePeer(siteId int) error {
pm.mu.Lock()
defer pm.mu.Unlock()
@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
wgConfig := promotedPeer
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}
@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
wgConfig := siteConfig
wgConfig.AllowedIps = ownedIPs
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
return err
}
@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
promotedWgConfig := promotedPeer
promotedWgConfig.AllowedIps = promotedOwnedIPs
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
}
}


@@ -31,8 +31,7 @@ type PeerMonitor struct {
monitors map[int]*Client
mutex sync.Mutex
running bool
interval time.Duration
timeout time.Duration
timeout time.Duration
maxAttempts int
wsClient *websocket.Client
@@ -50,11 +49,11 @@ type PeerMonitor struct {
// Holepunch testing fields
sharedBind *bind.SharedBind
holepunchTester *holepunch.HolepunchTester
holepunchInterval time.Duration
holepunchTimeout time.Duration
holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing
holepunchStatus map[int]bool // siteID -> connected status
holepunchStopChan chan struct{}
holepunchStopChan chan struct{}
holepunchUpdateChan chan struct{}
// Relay tracking fields
relayedPeers map[int]bool // siteID -> whether the peer is currently relayed
@@ -62,11 +61,13 @@ type PeerMonitor struct {
holepunchFailures map[int]int // siteID -> consecutive failure count
// Exponential backoff fields for holepunch monitor
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
defaultHolepunchMaxInterval time.Duration
holepunchMinInterval time.Duration // Minimum interval (initial)
holepunchMaxInterval time.Duration // Maximum interval (cap for backoff)
holepunchBackoffMultiplier float64 // Multiplier for each stable check
holepunchStableCount map[int]int // siteID -> consecutive stable status count
holepunchCurrentInterval time.Duration // Current interval with backoff applied
// Rapid initial test fields
rapidTestInterval time.Duration // interval between rapid test attempts
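The struct now carries both the live backoff bounds and their defaults so ResetPeerHolepunchInterval can restore them. The growth rule itself lives elsewhere in this file; a sketch of what these fields imply, under the assumption that the interval multiplies only after a run of stable checks and snaps back to the minimum on any status change:

func nextHolepunchInterval(current, minI, maxI time.Duration, multiplier float64, stableCount, threshold int, statusChanged bool) time.Duration {
	if statusChanged || stableCount < threshold {
		return minI // any change (or not yet stable) resets to the floor
	}
	next := time.Duration(float64(current) * multiplier)
	if next > maxI {
		next = maxI // cap the backoff
	}
	return next
}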
@@ -85,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
ctx, cancel := context.WithCancel(context.Background())
pm := &PeerMonitor{
monitors: make(map[int]*Client),
interval: 2 * time.Second, // Default check interval (faster)
timeout: 3 * time.Second,
maxAttempts: 3,
wsClient: wsClient,
@@ -95,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
nsCtx: ctx,
nsCancel: cancel,
sharedBind: sharedBind,
holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds
holepunchTimeout: 2 * time.Second, // Faster timeout
holepunchEndpoints: make(map[int]string),
holepunchStatus: make(map[int]bool),
@@ -109,11 +108,14 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
apiServer: apiServer,
wgConnectionStatus: make(map[int]bool),
// Exponential backoff settings for holepunch monitor
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
defaultHolepunchMinInterval: 2 * time.Second,
defaultHolepunchMaxInterval: 30 * time.Second,
holepunchMinInterval: 2 * time.Second,
holepunchMaxInterval: 30 * time.Second,
holepunchBackoffMultiplier: 1.5,
holepunchStableCount: make(map[int]int),
holepunchCurrentInterval: 2 * time.Second,
holepunchUpdateChan: make(chan struct{}, 1),
}
if err := pm.initNetstack(); err != nil {
@@ -129,41 +131,75 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
}
// SetInterval changes how frequently peers are checked
func (pm *PeerMonitor) SetInterval(interval time.Duration) {
func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.interval = interval
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetPacketInterval(interval)
client.SetPacketInterval(minInterval, maxInterval)
}
logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
}
// SetTimeout changes the timeout for waiting for responses
func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
func (pm *PeerMonitor) ResetPeerInterval() {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.timeout = timeout
// Update timeout for all existing monitors
// Update interval for all existing monitors
for _, client := range pm.monitors {
client.SetTimeout(timeout)
client.ResetPacketInterval()
}
}
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
pm.holepunchMinInterval = minInterval
pm.holepunchMaxInterval = maxInterval
// Reset current interval to the new minimum
pm.holepunchCurrentInterval = minInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
pm.mutex.Lock()
defer pm.mutex.Unlock()
pm.maxAttempts = attempts
return pm.holepunchMinInterval, pm.holepunchMaxInterval
}
// Update max attempts for all existing monitors
for _, client := range pm.monitors {
client.SetMaxAttempts(attempts)
func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
pm.mutex.Lock()
pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
updateChan := pm.holepunchUpdateChan
pm.mutex.Unlock()
logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
// Signal the goroutine to apply the new interval if running
if updateChan != nil {
select {
case updateChan <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
@@ -182,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
return err
}
client.SetPacketInterval(pm.interval)
client.SetTimeout(pm.timeout)
client.SetMaxAttempts(pm.maxAttempts)
client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
pm.monitors[siteID] = client
pm.holepunchEndpoints[siteID] = holepunchEndpoint
@@ -497,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
select {
case <-pm.holepunchStopChan:
return
case <-pm.holepunchUpdateChan:
// Interval settings changed, reset to minimum
pm.mutex.Lock()
pm.holepunchCurrentInterval = pm.holepunchMinInterval
currentInterval := pm.holepunchCurrentInterval
pm.mutex.Unlock()
timer.Reset(currentInterval)
logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
case <-timer.C:
anyStatusChanged := pm.checkHolepunchEndpoints()
@@ -540,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
anyStatusChanged := false
for siteID, endpoint := range endpoints {
// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
result := pm.holepunchTester.TestEndpoint(endpoint, timeout)
pm.mutex.Lock()
@@ -689,55 +729,55 @@ func (pm *PeerMonitor) Close() {
logger.Debug("PeerMonitor: Cleanup complete")
}
// TestPeer tests connectivity to a specific peer
func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
pm.mutex.Lock()
client, exists := pm.monitors[siteID]
pm.mutex.Unlock()
// // TestPeer tests connectivity to a specific peer
// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
// pm.mutex.Lock()
// client, exists := pm.monitors[siteID]
// pm.mutex.Unlock()
if !exists {
return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
}
// if !exists {
// return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
// }
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
defer cancel()
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// defer cancel()
connected, rtt := client.TestConnection(ctx)
return connected, rtt, nil
}
// connected, rtt := client.TestPeerConnection(ctx)
// return connected, rtt, nil
// }
// TestAllPeers tests connectivity to all peers
func (pm *PeerMonitor) TestAllPeers() map[int]struct {
Connected bool
RTT time.Duration
} {
pm.mutex.Lock()
peers := make(map[int]*Client, len(pm.monitors))
for siteID, client := range pm.monitors {
peers[siteID] = client
}
pm.mutex.Unlock()
// // TestAllPeers tests connectivity to all peers
// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
// Connected bool
// RTT time.Duration
// } {
// pm.mutex.Lock()
// peers := make(map[int]*Client, len(pm.monitors))
// for siteID, client := range pm.monitors {
// peers[siteID] = client
// }
// pm.mutex.Unlock()
results := make(map[int]struct {
Connected bool
RTT time.Duration
})
for siteID, client := range peers {
ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
connected, rtt := client.TestConnection(ctx)
cancel()
// results := make(map[int]struct {
// Connected bool
// RTT time.Duration
// })
// for siteID, client := range peers {
// ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
// connected, rtt := client.TestPeerConnection(ctx)
// cancel()
results[siteID] = struct {
Connected bool
RTT time.Duration
}{
Connected: connected,
RTT: rtt,
}
}
// results[siteID] = struct {
// Connected bool
// RTT time.Duration
// }{
// Connected: connected,
// RTT: rtt,
// }
// }
return results
}
// return results
// }
// initNetstack initializes the gvisor netstack
func (pm *PeerMonitor) initNetstack() error {


@@ -32,16 +32,19 @@ type Client struct {
monitorLock sync.Mutex
connLock sync.Mutex // Protects connection operations
shutdownCh chan struct{}
updateCh chan struct{}
packetInterval time.Duration
timeout time.Duration
maxAttempts int
dialer Dialer
// Exponential backoff fields
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
defaultMinInterval time.Duration // Default minimum interval (initial)
defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
minInterval time.Duration // Minimum interval (initial)
maxInterval time.Duration // Maximum interval (cap for backoff)
backoffMultiplier float64 // Multiplier for each stable check
stableCountToBackoff int // Number of stable checks before backing off
}
// Dialer is a function that creates a connection
@@ -56,43 +59,59 @@ type ConnectionStatus struct {
// NewClient creates a new connection test client
func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
return &Client{
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
packetInterval: 2 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
serverAddr: serverAddr,
shutdownCh: make(chan struct{}),
updateCh: make(chan struct{}, 1),
packetInterval: 2 * time.Second,
defaultMinInterval: 2 * time.Second,
defaultMaxInterval: 30 * time.Second,
minInterval: 2 * time.Second,
maxInterval: 30 * time.Second,
backoffMultiplier: 1.5,
stableCountToBackoff: 3, // After 3 consecutive same-state results, start backing off
timeout: 500 * time.Millisecond, // Timeout for individual packets
maxAttempts: 3, // Default max attempts
dialer: dialer,
}, nil
}
// SetPacketInterval changes how frequently packets are sent in monitor mode
func (c *Client) SetPacketInterval(interval time.Duration) {
c.packetInterval = interval
c.minInterval = interval
func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
c.monitorLock.Lock()
c.packetInterval = minInterval
c.minInterval = minInterval
c.maxInterval = maxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// SetTimeout changes the timeout for waiting for responses
func (c *Client) SetTimeout(timeout time.Duration) {
c.timeout = timeout
}
func (c *Client) ResetPacketInterval() {
c.monitorLock.Lock()
c.packetInterval = c.defaultMinInterval
c.minInterval = c.defaultMinInterval
c.maxInterval = c.defaultMaxInterval
updateCh := c.updateCh
monitorRunning := c.monitorRunning
c.monitorLock.Unlock()
// SetMaxAttempts changes the maximum number of attempts for TestConnection
func (c *Client) SetMaxAttempts(attempts int) {
c.maxAttempts = attempts
}
// SetMaxInterval sets the maximum backoff interval
func (c *Client) SetMaxInterval(interval time.Duration) {
c.maxInterval = interval
}
// SetBackoffMultiplier sets the multiplier for exponential backoff
func (c *Client) SetBackoffMultiplier(multiplier float64) {
c.backoffMultiplier = multiplier
// Signal the goroutine to apply the new interval if running
if monitorRunning && updateCh != nil {
select {
case updateCh <- struct{}{}:
default:
// Channel full or closed, skip
}
}
}
// UpdateServerAddr updates the server address and resets the connection
@@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error {
return nil
}
// TestConnection checks if the connection to the server is working
// TestPeerConnection checks if the connection to the server is working
// Returns true if connected, false otherwise
func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
if err := c.ensureConnection(); err != nil {
logger.Warn("Failed to ensure connection: %v", err)
return false, 0
@@ -232,7 +252,7 @@ func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return c.TestConnection(ctx)
return c.TestPeerConnection(ctx)
}
// MonitorCallback is the function type for connection status change callbacks
@@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
select {
case <-c.shutdownCh:
return
case <-c.updateCh:
// Interval settings changed, reset to minimum
c.monitorLock.Lock()
currentInterval = c.minInterval
c.monitorLock.Unlock()
// Reset backoff state
stableCount = 0
timer.Reset(currentInterval)
logger.Debug("Packet interval updated, reset to %v", currentInterval)
case <-timer.C:
ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
connected, rtt := c.TestConnection(ctx)
connected, rtt := c.TestPeerConnection(ctx)
cancel()
statusChanged := connected != lastConnected


@@ -11,7 +11,7 @@ import (
)
// ConfigurePeer sets up or updates a peer within the WireGuard device
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
var endpoint string
if relay && siteConfig.RelayEndpoint != "" {
endpoint = formatEndpoint(siteConfig.RelayEndpoint)
@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
}
configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
configBuilder.WriteString("persistent_keepalive_interval=5\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))
config := configBuilder.String()
logger.Debug("Configuring peer with config: %s", config)
@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
return nil
}
// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
var configBuilder strings.Builder
configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
configBuilder.WriteString("update_only=true\n")
configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
config := configBuilder.String()
logger.Debug("Updating persistent keepalive for peer with config: %s", config)
err := dev.IpcSet(config)
if err != nil {
return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
}
return nil
}
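update_only=true is the wireguard-go UAPI flag that modifies an existing peer instead of creating one, so only the keepalive value changes; an interval of 0 switches keepalive off. A sketch of single-peer usage:

// Disable keepalive for one peer without touching its endpoint or allowed IPs.
if err := UpdatePersistentKeepalive(dev, peer.PublicKey, 0); err != nil {
	logger.Error("failed to disable keepalive for peer: %v", err)
}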
func formatEndpoint(endpoint string) string {
if strings.Contains(endpoint, ":") {
return endpoint


@@ -77,6 +77,7 @@ type Client struct {
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
@@ -87,6 +88,10 @@ type Client struct {
clientType string // Type of client (e.g., "newt", "olm")
tlsConfig TLSConfig
configNeedsSave bool // Flag to track if config needs to be saved
token string // Cached authentication token
exitNodes []ExitNode // Cached exit nodes from token response
tokenMux sync.RWMutex // Protects token and exitNodes
forceNewToken bool // Flag to force fetching a new token on next connection
}
type ClientOption func(*Client)
@@ -173,6 +178,9 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
if c.isDisconnected {
c.isDisconnected = false
}
go c.connectWithRetry()
return nil
}
@@ -205,9 +213,25 @@ func (c *Client) Close() error {
return nil
}
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
c.isDisconnected = true
c.setConnected(false)
if c.conn != nil {
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
err := c.conn.Close()
c.conn = nil
return err
}
return nil
}
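A sketch of the intended suspend/resume cycle on the client. Disconnect is synchronous and parks interval senders on the isDisconnected flag; Connect clears the flag and reconnects asynchronously through connectWithRetry.

if err := c.Disconnect(); err != nil {
	logger.Error("websocket: disconnect: %v", err)
}
// ... low power period: pings and interval sends are suspended ...
_ = c.Connect() // returns immediately; the retry loop runs in a goroutine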
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return fmt.Errorf("not connected")
}
@@ -216,14 +240,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
Data: data,
}
logger.Debug("Sending message: %s, data: %+v", messageType, data)
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
c.writeMux.Lock()
defer c.writeMux.Unlock()
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
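The new maxAttempts parameter distinguishes bounded retries from the always-on server ping. A usage sketch mirroring the call sites above (regData is a hypothetical payload):

stopPing, _ := c.SendMessageInterval("olm/ping", map[string]any{
	"timestamp": time.Now().Unix(),
}, 30*time.Second, -1) // -1: keep sending until stopped
defer stopPing()

stopRegister, updateRegister := c.SendMessageInterval("olm/register", regData, 1*time.Second, 10) // give up after 10 sends
defer stopRegister()
_ = updateRegister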
@@ -231,30 +255,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
send := func() {
if c.isDisconnected || c.conn == nil {
return
}
err := c.SendMessage(messageType, currentData)
if err != nil {
logger.Error("websocket: Failed to send message: %v", err)
}
count++
}
count++
send() // Send immediately
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
if maxAttempts != -1 && count >= maxAttempts {
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
send()
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
@@ -277,6 +303,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan:
return
}
// Suspend sending if disconnected
for c.isDisconnected {
select {
case <-stopChan:
return
case <-time.After(500 * time.Millisecond):
}
}
}
}()
return func() {
@@ -323,7 +357,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
tlsConfig = &tls.Config{}
}
tlsConfig.InsecureSkipVerify = true
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
tokenData := map[string]interface{}{
@@ -352,7 +386,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// print out the request for debugging
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
// Make the request
client := &http.Client{}
@@ -369,7 +403,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(resp.Body)
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
// Return AuthError for 401/403 status codes
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
@@ -385,7 +419,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
logger.Error("Failed to decode token response.")
logger.Error("websocket: Failed to decode token response.")
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
}
@@ -397,7 +431,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
return "", nil, fmt.Errorf("received empty token from server")
}
logger.Debug("Received token: %s", tokenResp.Data.Token)
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
}
@@ -412,7 +446,7 @@ func (c *Client) connectWithRetry() {
if err != nil {
// Check if this is an auth error (401/403)
if authErr, ok := err.(*AuthError); ok {
logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr)
// Trigger auth error callback if set (this should terminate the tunnel)
if c.onAuthError != nil {
c.onAuthError(authErr.StatusCode, authErr.Message)
@@ -422,7 +456,7 @@ func (c *Client) connectWithRetry() {
continue
}
// For other errors (5xx, network issues), continue retrying
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
@@ -432,15 +466,25 @@ func (c *Client) connectWithRetry() {
}
func (c *Client) establishConnection() error {
// Get token for authentication
token, exitNodes, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
// Get token for authentication - reuse cached token unless forced to get new one
c.tokenMux.Lock()
needNewToken := c.token == "" || c.forceNewToken
if needNewToken {
token, exitNodes, err := c.getToken()
if err != nil {
c.tokenMux.Unlock()
return fmt.Errorf("failed to get token: %w", err)
}
c.token = token
c.exitNodes = exitNodes
c.forceNewToken = false
if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes)
if c.onTokenUpdate != nil {
c.onTokenUpdate(token, exitNodes)
}
}
token := c.token
c.tokenMux.Unlock()
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
@@ -475,7 +519,7 @@ func (c *Client) establishConnection() error {
// Use new TLS configuration method
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
logger.Info("Setting up TLS configuration for WebSocket connection")
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
tlsConfig, err := c.setupTLS()
if err != nil {
return fmt.Errorf("failed to setup TLS configuration: %w", err)
@@ -489,11 +533,23 @@ func (c *Client) establishConnection() error {
dialer.TLSClientConfig = &tls.Config{}
}
dialer.TLSClientConfig.InsecureSkipVerify = true
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
}
conn, _, err := dialer.Dial(u.String(), nil)
conn, resp, err := dialer.Dial(u.String(), nil)
if err != nil {
// Check if this is an unauthorized error (401)
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
// Force getting a new token on next reconnect attempt
c.tokenMux.Lock()
c.forceNewToken = true
c.tokenMux.Unlock()
return &AuthError{
StatusCode: http.StatusUnauthorized,
Message: "WebSocket connection unauthorized",
}
}
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
@@ -507,7 +563,7 @@ func (c *Client) establishConnection() error {
if c.onConnect != nil {
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
logger.Error("websocket: OnConnect callback failed: %v", err)
}
}
@@ -520,9 +576,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Handle new separate certificate configuration
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
logger.Info("Loading separate certificate files for mTLS")
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
logger.Info("websocket: Loading separate certificate files for mTLS")
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
// Load client certificate and key
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
@@ -533,7 +589,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Load CA certificates for remote validation if specified
if len(c.tlsConfig.CAFiles) > 0 {
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
caCertPool := x509.NewCertPool()
for _, caFile := range c.tlsConfig.CAFiles {
caCert, err := os.ReadFile(caFile)
@@ -559,13 +615,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
// Fallback to existing PKCS12 implementation for backward compatibility
if c.tlsConfig.PKCS12File != "" {
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
return c.setupPKCS12TLS()
}
// Legacy fallback using config.TlsClientCert
if c.config.TlsClientCert != "" {
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
return loadClientCertificate(c.config.TlsClientCert)
}
@@ -587,7 +643,7 @@ func (c *Client) pingMonitor() {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return
}
c.writeMux.Lock()
@@ -600,7 +656,7 @@ func (c *Client) pingMonitor() {
// Expected during shutdown
return
default:
logger.Error("Ping failed: %v", err)
logger.Error("websocket: Ping failed: %v", err)
c.reconnect()
return
}
@@ -633,18 +689,24 @@ func (c *Client) readPumpWithDisconnectDetection() {
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
// Check if we're shutting down before logging error
// Check if we're shutting down or explicitly disconnected before logging error
select {
case <-c.done:
// Expected during shutdown, don't log as error
logger.Debug("WebSocket connection closed during shutdown")
logger.Debug("websocket: connection closed during shutdown")
return
default:
// Check if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: connection closed: client was explicitly disconnected")
return
}
// Unexpected error during normal operation
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
logger.Error("WebSocket read error: %v", err)
logger.Error("websocket: read error: %v", err)
} else {
logger.Debug("WebSocket connection closed: %v", err)
logger.Debug("websocket: connection closed: %v", err)
}
return // triggers reconnect via defer
}
@@ -666,6 +728,12 @@ func (c *Client) reconnect() {
c.conn = nil
}
// Don't reconnect if explicitly disconnected
if c.isDisconnected {
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
return
}
// Only reconnect if we're not shutting down
select {
case <-c.done:
@@ -683,7 +751,7 @@ func (c *Client) setConnected(status bool) {
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
func loadClientCertificate(p12Path string) (*tls.Config, error) {
logger.Info("Loading tls-client-cert %s", p12Path)
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
// Read the PKCS12 file
p12Data, err := os.ReadFile(p12Path)
if err != nil {