Merge branch 'power-state' into dev

Former-commit-id: e2a071e6dc
Author: Owen
Date: 2026-01-15 16:39:41 -08:00

13 changed files with 553 additions and 252 deletions

go.mod
View File

@@ -30,3 +30,5 @@ require (
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
 	golang.zx2c4.com/wireguard/windows v0.5.3 // indirect
 )
+
+replace github.com/fosrl/newt => ../newt

go.sum
View File

@@ -1,7 +1,5 @@
 github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
 github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
-github.com/fosrl/newt v1.8.0 h1:wIRCO2shhCpkFzsbNbb4g2LC7mPzIpp2ialNveBMJy4=
-github.com/fosrl/newt v1.8.0/go.mod h1:pol958CEs0nQmo/35Ltv0CGksheIKCS2hoNvdTVLEcI=
 github.com/godbus/dbus/v5 v5.2.0 h1:3WexO+U+yg9T70v9FdHr9kCxYlazaAXUhx2VMkbfax8=
 github.com/godbus/dbus/v5 v5.2.0/go.mod h1:3AAv2+hPq5rdnr5txxxRwiGjPXamgoIHgz9FPBfOp3c=
 github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=

View File

@@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
 		MiddleDev:  o.middleDev,
 		LocalIP:    interfaceIP,
 		SharedBind: o.sharedBind,
-		WSClient:   o.olmClient,
+		WSClient:   o.websocket,
 		APIServer:  o.apiServer,
 	})

View File

@@ -186,12 +186,12 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
 	}

 	o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
-	o.holePunchManager.ResetInterval()    // start sending immediately again so we fill in the endpoint on the cloud
+	o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud

 	// Send handshake acknowledgment back to server with retry
-	o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
+	o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
 		"siteId": handshakeData.SiteId,
-	}, 1*time.Second)
+	}, 1*time.Second, 10)

 	logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
 }

View File

@@ -42,9 +42,14 @@ type Olm struct {
 	dnsProxy         *dns.DNSProxy
 	apiServer        *api.API
-	olmClient        *websocket.Client
+	websocket        *websocket.Client
 	holePunchManager *holepunch.Manager
 	peerManager      *peers.PeerManager

+	// Power mode management
+	currentPowerMode string
+	powerModeMu      sync.Mutex
+	wakeUpTimer      *time.Timer
+	wakeUpDebounce   time.Duration
+
 	olmCtx       context.Context
 	tunnelCancel context.CancelFunc

@@ -58,10 +63,11 @@ type Olm struct {
 	metaMu sync.Mutex

 	stopRegister   func()
-	stopPeerSend   func()
 	updateRegister func(newData any)
-	stopPing       chan struct{}
+	stopServerPing func()
+	stopPeerSend   func()
 }

 // initTunnelInfo creates the shared UDP socket and holepunch manager.
@@ -134,6 +140,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
 		logFile = file
 	}

+	if config.WakeUpDebounce == 0 {
+		config.WakeUpDebounce = 3 * time.Second
+	}
+
 	logger.Debug("Checking permissions for native interface")
 	err := permissions.CheckNativeInterfacePermissions()
 	if err != nil {

@@ -285,9 +295,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 	tunnelCtx, cancel := context.WithCancel(o.olmCtx)
 	o.tunnelCancel = cancel

-	// Recreate channels for this tunnel session
-	o.stopPing = make(chan struct{})
-
 	var (
 		id     = config.ID
 		secret = config.Secret
@@ -343,6 +350,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 		o.apiServer.SetConnectionStatus(true)

+		// restart the ping if we need to
+		if o.stopServerPing == nil {
+			o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
+				"timestamp": time.Now().Unix(),
+				"userToken": olmClient.GetConfig().UserToken,
+			}, 30*time.Second, -1) // -1 means don't time out with the max attempts
+		}
+
 		if o.connected {
 			logger.Debug("Already connected, skipping registration")
 			return nil

@@ -364,7 +379,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 			"userToken":   userToken,
 			"fingerprint": o.fingerprint,
 			"postures":    o.postures,
-		}, 1*time.Second)
+		}, 1*time.Second, 10)

 		// Invoke onRegistered callback if configured
 		if o.olmConfig.OnRegistered != nil {

@@ -372,8 +387,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 			}
 		}

-		go o.keepSendingPing(olmClient)
-
 		return nil
 	})
@@ -446,7 +459,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
 	}
 	defer func() { _ = olmClient.Close() }()

-	o.olmClient = olmClient
+	o.websocket = olmClient

 	// Wait for context cancellation
 	<-tunnelCtx.Done()

@@ -465,9 +478,9 @@ func (o *Olm) Close() {
 		o.holePunchManager = nil
 	}

-	if o.stopPing != nil {
-		close(o.stopPing)
-		o.stopPing = nil
+	if o.stopServerPing != nil {
+		o.stopServerPing()
+		o.stopServerPing = nil
 	}

 	if o.stopRegister != nil {

@@ -545,9 +558,9 @@ func (o *Olm) StopTunnel() error {
 	}

 	// Close the websocket connection
-	if o.olmClient != nil {
-		_ = o.olmClient.Close()
-		o.olmClient = nil
+	if o.websocket != nil {
+		_ = o.websocket.Close()
+		o.websocket = nil
 	}

 	o.Close()
@@ -625,3 +638,157 @@ func (o *Olm) SetPostures(data map[string]any) {
 	o.postures = data
 }
+
+// SetPowerMode switches between normal and low power modes.
+// In low power mode: the websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes.
+// In normal power mode: the websocket is reconnected (restarting pings) and monitoring intervals are restored.
+// Wake-up is debounced (default 3 seconds) to prevent rapid flip-flopping; sleep is immediate.
+func (o *Olm) SetPowerMode(mode string) error {
+	// Validate mode
+	if mode != "normal" && mode != "low" {
+		return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode)
+	}
+
+	o.powerModeMu.Lock()
+	defer o.powerModeMu.Unlock()
+
+	// If already in the requested mode, return early
+	if o.currentPowerMode == mode {
+		// Cancel any pending wake-up timer if we're already in normal mode
+		if mode == "normal" && o.wakeUpTimer != nil {
+			o.wakeUpTimer.Stop()
+			o.wakeUpTimer = nil
+		}
+		logger.Debug("Already in %s power mode", mode)
+		return nil
+	}
+
+	if mode == "low" {
+		// Low power mode: cancel any pending wake-up and immediately go to sleep
+		if o.wakeUpTimer != nil {
+			logger.Debug("Cancelling pending wake-up timer")
+			o.wakeUpTimer.Stop()
+			o.wakeUpTimer = nil
+		}
+
+		logger.Info("Switching to low power mode")
+
+		if o.websocket != nil {
+			logger.Info("Disconnecting websocket for low power mode")
+			if err := o.websocket.Disconnect(); err != nil {
+				logger.Error("Error disconnecting websocket: %v", err)
+			}
+		}
+
+		lowPowerInterval := 10 * time.Minute
+		if o.peerManager != nil {
+			peerMonitor := o.peerManager.GetPeerMonitor()
+			if peerMonitor != nil {
+				peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
+				peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
+				logger.Info("Set monitoring intervals to 10 minutes for low power mode")
+			}
+			o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
+		}
+		if o.holePunchManager != nil {
+			o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
+		}
+
+		o.currentPowerMode = "low"
+		logger.Info("Switched to low power mode")
+	} else {
+		// Normal power mode: start a debounce timer before actually waking up.
+		// If there's already a pending wake-up timer, don't start another.
+		if o.wakeUpTimer != nil {
+			logger.Debug("Wake-up already pending, ignoring duplicate request")
+			return nil
+		}
+
+		logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce)
+		o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() {
+			o.powerModeMu.Lock()
+			defer o.powerModeMu.Unlock()
+
+			// Clear the timer reference
+			o.wakeUpTimer = nil
+
+			// Double-check we're still in low power mode (could have changed)
+			if o.currentPowerMode == "normal" {
+				logger.Debug("Already in normal mode after debounce, skipping wake-up")
+				return
+			}
+
+			logger.Info("Debounce complete, switching to normal power mode")
+
+			// Reconnect the websocket first
+			logger.Info("Reconnecting websocket for normal power mode")
+			if o.websocket != nil {
+				if err := o.websocket.Connect(); err != nil {
+					logger.Error("Failed to reconnect websocket: %v", err)
+					return
+				}
+			}
+
+			// Restore monitoring intervals and keepalives
+			if o.peerManager != nil {
+				peerMonitor := o.peerManager.GetPeerMonitor()
+				if peerMonitor != nil {
+					peerMonitor.ResetPeerHolepunchInterval()
+					peerMonitor.ResetPeerInterval()
+				}
+				o.peerManager.UpdateAllPeersPersistentKeepalive(5)
+			}
+			if o.holePunchManager != nil {
+				o.holePunchManager.ResetServerHolepunchInterval()
+			}
+
+			o.currentPowerMode = "normal"
+			logger.Info("Switched to normal power mode")
+		})
+	}
+
+	return nil
+}
+
+func (o *Olm) AddDevice(fd uint32) error {
+	if o.middleDev == nil {
+		return fmt.Errorf("middle device is not initialized")
+	}
+	if o.tunnelConfig.MTU == 0 {
+		return fmt.Errorf("tunnel MTU is not set")
+	}
+
+	tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU)
+	if err != nil {
+		return fmt.Errorf("failed to create TUN device from fd: %v", err)
+	}
+
+	// Update interface name if available
+	if realInterfaceName, err2 := tdev.Name(); err2 == nil {
+		o.tunnelConfig.InterfaceName = realInterfaceName
+	}
+
+	// Replace the existing TUN device in the middle device with the new one
+	o.middleDev.AddDevice(tdev)
+
+	logger.Info("Added device from file descriptor %d", fd)
+	return nil
+}
+
+func GetNetworkSettingsJSON() (string, error) {
+	return network.GetJSON()
+}
+
+func GetNetworkSettingsIncrementor() int {
+	return network.GetIncrementor()
+}
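
The wake-up path above is the subtle part of this commit: entering low power is immediate, but leaving it only commits after o.wakeUpDebounce elapses with no intervening sleep, so a machine flapping between suspend and resume does not repeatedly tear down and redial the websocket. A minimal, self-contained sketch of that debounce pattern (the type and names are illustrative, not part of this diff):

package main

import (
	"fmt"
	"sync"
	"time"
)

// powerSwitch mirrors the SetPowerMode debounce above: "low" applies
// immediately and cancels any pending wake-up; "normal" only takes
// effect after a quiet period.
type powerSwitch struct {
	mu       sync.Mutex
	mode     string
	timer    *time.Timer
	debounce time.Duration
}

func (p *powerSwitch) Set(mode string) {
	p.mu.Lock()
	defer p.mu.Unlock()
	if mode == "low" {
		if p.timer != nil { // cancel a pending wake-up
			p.timer.Stop()
			p.timer = nil
		}
		p.mode = "low"
		return
	}
	if p.timer != nil { // a wake-up is already pending
		return
	}
	p.timer = time.AfterFunc(p.debounce, func() {
		p.mu.Lock()
		defer p.mu.Unlock()
		p.timer = nil
		p.mode = "normal" // debounce survived: actually wake up
	})
}

func main() {
	p := &powerSwitch{mode: "normal", debounce: 50 * time.Millisecond}
	p.Set("low")
	p.Set("normal") // starts the debounce timer
	p.Set("low")    // flaps back before it fires: wake-up is cancelled
	time.Sleep(100 * time.Millisecond)
	fmt.Println(p.mode) // still "low"
}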

View File

@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
 	if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
 		logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
 		_ = o.holePunchManager.TriggerHolePunch()
-		o.holePunchManager.ResetInterval()
+		o.holePunchManager.ResetServerHolepunchInterval()
 	}

 	logger.Info("Successfully updated peer for site %d", updateData.SiteId)
View File

@@ -1,57 +0,0 @@
-package olm
-
-import (
-	"time"
-
-	"github.com/fosrl/newt/logger"
-	"github.com/fosrl/newt/network"
-	"github.com/fosrl/olm/websocket"
-)
-
-func (o *Olm) sendPing(olm *websocket.Client) error {
-	err := olm.SendMessage("olm/ping", map[string]any{
-		"timestamp":   time.Now().Unix(),
-		"userToken":   olm.GetConfig().UserToken,
-		"fingerprint": o.fingerprint,
-		"postures":    o.postures,
-	})
-	if err != nil {
-		logger.Error("Failed to send ping message: %v", err)
-		return err
-	}
-	logger.Debug("Sent ping message")
-	return nil
-}
-
-func (o *Olm) keepSendingPing(olm *websocket.Client) {
-	// Send ping immediately on startup
-	if err := o.sendPing(olm); err != nil {
-		logger.Error("Failed to send initial ping: %v", err)
-	} else {
-		logger.Info("Sent initial ping message")
-	}
-
-	// Set up ticker for 30 second intervals
-	ticker := time.NewTicker(30 * time.Second)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-o.stopPing:
-			logger.Info("Stopping ping messages")
-			return
-		case <-ticker.C:
-			if err := o.sendPing(olm); err != nil {
-				logger.Error("Failed to send periodic ping: %v", err)
-			}
-		}
-	}
-}
-
-func GetNetworkSettingsJSON() (string, error) {
-	return network.GetJSON()
-}
-
-func GetNetworkSettingsIncrementor() int {
-	return network.GetIncrementor()
-}

View File

@@ -24,6 +24,8 @@ type OlmConfig struct {
 	Version string
 	Agent   string

+	WakeUpDebounce time.Duration
+
 	// Debugging
 	PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")

View File

@@ -50,6 +50,8 @@ type PeerManager struct {
 	// key is the CIDR string, value is a set of siteIds that want this IP
 	allowedIPClaims map[string]map[int]bool
 	APIServer       *api.API
+
+	PersistentKeepalive int
 }

 // NewPeerManager creates a new PeerManager with an internal PeerMonitor

@@ -84,6 +86,13 @@ func (pm *PeerManager) GetPeer(siteId int) (SiteConfig, bool) {
 	return peer, ok
 }

+// GetPeerMonitor returns the internal peer monitor instance
+func (pm *PeerManager) GetPeerMonitor() *monitor.PeerMonitor {
+	pm.mu.RLock()
+	defer pm.mu.RUnlock()
+	return pm.peerMonitor
+}
+
 func (pm *PeerManager) GetAllPeers() []SiteConfig {
 	pm.mu.RLock()
 	defer pm.mu.RUnlock()

@@ -120,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
 	wgConfig := siteConfig
 	wgConfig.AllowedIps = ownedIPs

-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
 		return err
 	}

@@ -159,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error {
 	return nil
 }

+// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once
+// without recreating them. Returns a map of siteId to error for any peers that failed to update.
+func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error {
+	pm.mu.RLock()
+	defer pm.mu.RUnlock()
+
+	pm.PersistentKeepalive = interval
+
+	errors := make(map[int]error)
+	for siteId, peer := range pm.peers {
+		err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval)
+		if err != nil {
+			errors[siteId] = err
+		}
+	}
+
+	if len(errors) == 0 {
+		return nil
+	}
+	return errors
+}
+
 func (pm *PeerManager) RemovePeer(siteId int) error {
 	pm.mu.Lock()
 	defer pm.mu.Unlock()

@@ -238,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error {
 	ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
 	wgConfig := promotedPeer
 	wgConfig.AllowedIps = ownedIPs
-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
 		logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
 	}
 }

@@ -314,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
 	wgConfig := siteConfig
 	wgConfig.AllowedIps = ownedIPs

-	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil {
+	if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil {
 		return err
 	}

@@ -324,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error {
 	promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId)
 	promotedWgConfig := promotedPeer
 	promotedWgConfig.AllowedIps = promotedOwnedIPs
-	if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil {
+	if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil {
 		logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err)
 	}
 }
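
UpdateAllPeersPersistentKeepalive deliberately keeps going after a failed peer and reports failures per site rather than aborting on the first error, which matters when a power transition has to touch every peer. A self-contained sketch of that collect-errors loop, with hypothetical names standing in for the WireGuard calls:

package main

import "fmt"

// updateAll mirrors the pattern above: apply one setting to every peer,
// collecting per-peer failures; a nil map means everything succeeded.
func updateAll(peers map[int]string, apply func(publicKey string) error) map[int]error {
	errs := make(map[int]error)
	for siteID, key := range peers {
		if err := apply(key); err != nil {
			errs[siteID] = err
		}
	}
	if len(errs) == 0 {
		return nil
	}
	return errs
}

func main() {
	peers := map[int]string{1: "keyA", 2: "keyB"}
	// apply would call UpdatePersistentKeepalive in the real code.
	errs := updateAll(peers, func(k string) error {
		fmt.Println("persistent_keepalive_interval=0 for", k)
		return nil
	})
	fmt.Println("failures:", errs)
}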

View File

@@ -31,7 +31,6 @@ type PeerMonitor struct {
 	monitors    map[int]*Client
 	mutex       sync.Mutex
 	running     bool
-	interval    time.Duration
 	timeout     time.Duration
 	maxAttempts int
 	wsClient    *websocket.Client

@@ -50,11 +49,11 @@ type PeerMonitor struct {
 	// Holepunch testing fields
 	sharedBind          *bind.SharedBind
 	holepunchTester     *holepunch.HolepunchTester
-	holepunchInterval   time.Duration
 	holepunchTimeout    time.Duration
 	holepunchEndpoints  map[int]string // siteID -> endpoint for holepunch testing
 	holepunchStatus     map[int]bool   // siteID -> connected status
 	holepunchStopChan   chan struct{}
+	holepunchUpdateChan chan struct{}

 	// Relay tracking fields
 	relayedPeers map[int]bool // siteID -> whether the peer is currently relayed

@@ -62,6 +61,8 @@ type PeerMonitor struct {
 	holepunchFailures map[int]int // siteID -> consecutive failure count

 	// Exponential backoff fields for holepunch monitor
+	defaultHolepunchMinInterval time.Duration // Minimum interval (initial)
+	defaultHolepunchMaxInterval time.Duration
 	holepunchMinInterval        time.Duration // Minimum interval (initial)
 	holepunchMaxInterval        time.Duration // Maximum interval (cap for backoff)
 	holepunchBackoffMultiplier  float64       // Multiplier for each stable check

@@ -85,7 +86,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 	ctx, cancel := context.WithCancel(context.Background())

 	pm := &PeerMonitor{
 		monitors:    make(map[int]*Client),
-		interval:    2 * time.Second, // Default check interval (faster)
 		timeout:     3 * time.Second,
 		maxAttempts: 3,
 		wsClient:    wsClient,

@@ -95,7 +95,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		nsCtx:              ctx,
 		nsCancel:           cancel,
 		sharedBind:         sharedBind,
-		holepunchInterval:  2 * time.Second, // Check holepunch every 2 seconds
 		holepunchTimeout:   2 * time.Second, // Faster timeout
 		holepunchEndpoints: make(map[int]string),
 		holepunchStatus:    make(map[int]bool),

@@ -109,11 +108,14 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe
 		apiServer:          apiServer,
 		wgConnectionStatus: make(map[int]bool),
 		// Exponential backoff settings for holepunch monitor
+		defaultHolepunchMinInterval: 2 * time.Second,
+		defaultHolepunchMaxInterval: 30 * time.Second,
 		holepunchMinInterval:        2 * time.Second,
 		holepunchMaxInterval:        30 * time.Second,
 		holepunchBackoffMultiplier:  1.5,
 		holepunchStableCount:        make(map[int]int),
 		holepunchCurrentInterval:    2 * time.Second,
+		holepunchUpdateChan:         make(chan struct{}, 1),
 	}

 	if err := pm.initNetstack(); err != nil {
@@ -129,41 +131,75 @@
 	}

 // SetInterval changes how frequently peers are checked
-func (pm *PeerMonitor) SetInterval(interval time.Duration) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-	pm.interval = interval
-
-	// Update interval for all existing monitors
-	for _, client := range pm.monitors {
-		client.SetPacketInterval(interval)
-	}
-}
-
-// SetTimeout changes the timeout for waiting for responses
-func (pm *PeerMonitor) SetTimeout(timeout time.Duration) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-	pm.timeout = timeout
-
-	// Update timeout for all existing monitors
-	for _, client := range pm.monitors {
-		client.SetTimeout(timeout)
-	}
-}
-
-// SetMaxAttempts changes the maximum number of attempts for TestConnection
-func (pm *PeerMonitor) SetMaxAttempts(attempts int) {
-	pm.mutex.Lock()
-	defer pm.mutex.Unlock()
-	pm.maxAttempts = attempts
-
-	// Update max attempts for all existing monitors
-	for _, client := range pm.monitors {
-		client.SetMaxAttempts(attempts)
-	}
-}
+func (pm *PeerMonitor) SetPeerInterval(minInterval, maxInterval time.Duration) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+
+	// Update interval for all existing monitors
+	for _, client := range pm.monitors {
+		client.SetPacketInterval(minInterval, maxInterval)
+	}
+
+	logger.Info("Set peer monitor interval to min: %s, max: %s", minInterval, maxInterval)
+}
+
+func (pm *PeerMonitor) ResetPeerInterval() {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+
+	// Update interval for all existing monitors
+	for _, client := range pm.monitors {
+		client.ResetPacketInterval()
+	}
+}
+
+// SetPeerHolepunchInterval sets both the minimum and maximum intervals for holepunch monitoring
+func (pm *PeerMonitor) SetPeerHolepunchInterval(minInterval, maxInterval time.Duration) {
+	pm.mutex.Lock()
+	pm.holepunchMinInterval = minInterval
+	pm.holepunchMaxInterval = maxInterval
+	// Reset current interval to the new minimum
+	pm.holepunchCurrentInterval = minInterval
+	updateChan := pm.holepunchUpdateChan
+	pm.mutex.Unlock()
+
+	logger.Info("Set holepunch interval to min: %s, max: %s", minInterval, maxInterval)
+
+	// Signal the goroutine to apply the new interval if running
+	if updateChan != nil {
+		select {
+		case updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+}
+
+// GetPeerHolepunchIntervals returns the current minimum and maximum intervals for holepunch monitoring
+func (pm *PeerMonitor) GetPeerHolepunchIntervals() (minInterval, maxInterval time.Duration) {
+	pm.mutex.Lock()
+	defer pm.mutex.Unlock()
+	return pm.holepunchMinInterval, pm.holepunchMaxInterval
+}
+
+func (pm *PeerMonitor) ResetPeerHolepunchInterval() {
+	pm.mutex.Lock()
+	pm.holepunchMinInterval = pm.defaultHolepunchMinInterval
+	pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval
+	pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval
+	updateChan := pm.holepunchUpdateChan
+	pm.mutex.Unlock()
+
+	logger.Info("Reset holepunch interval to defaults: min=%v, max=%v", pm.defaultHolepunchMinInterval, pm.defaultHolepunchMaxInterval)
+
+	// Signal the goroutine to apply the new interval if running
+	if updateChan != nil {
+		select {
+		case updateChan <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+}
@@ -182,11 +218,6 @@ func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint st
 		return err
 	}

-	client.SetPacketInterval(pm.interval)
-	client.SetTimeout(pm.timeout)
-	client.SetMaxAttempts(pm.maxAttempts)
-	client.SetMaxInterval(30 * time.Second) // Allow backoff up to 30 seconds when stable
-
 	pm.monitors[siteID] = client
 	pm.holepunchEndpoints[siteID] = holepunchEndpoint

@@ -497,6 +528,15 @@ func (pm *PeerMonitor) runHolepunchMonitor() {
 		select {
 		case <-pm.holepunchStopChan:
 			return
+		case <-pm.holepunchUpdateChan:
+			// Interval settings changed, reset to minimum
+			pm.mutex.Lock()
+			pm.holepunchCurrentInterval = pm.holepunchMinInterval
+			currentInterval := pm.holepunchCurrentInterval
+			pm.mutex.Unlock()
+
+			timer.Reset(currentInterval)
+			logger.Debug("Holepunch monitor interval updated, reset to %v", currentInterval)
 		case <-timer.C:
 			anyStatusChanged := pm.checkHolepunchEndpoints()

@@ -540,7 +580,7 @@ func (pm *PeerMonitor) checkHolepunchEndpoints() bool {
 	anyStatusChanged := false
 	for siteID, endpoint := range endpoints {
-		// logger.Debug("Testing holepunch endpoint for site %d: %s", siteID, endpoint)
+		logger.Debug("holepunchTester: testing endpoint for site %d: %s", siteID, endpoint)
 		result := pm.holepunchTester.TestEndpoint(endpoint, timeout)

 		pm.mutex.Lock()
@@ -689,55 +729,55 @@ func (pm *PeerMonitor) Close() {
 	logger.Debug("PeerMonitor: Cleanup complete")
 }

-// TestPeer tests connectivity to a specific peer
-func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
-	pm.mutex.Lock()
-	client, exists := pm.monitors[siteID]
-	pm.mutex.Unlock()
-
-	if !exists {
-		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
-	}
-
-	ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
-	defer cancel()
-
-	connected, rtt := client.TestConnection(ctx)
-	return connected, rtt, nil
-}
-
-// TestAllPeers tests connectivity to all peers
-func (pm *PeerMonitor) TestAllPeers() map[int]struct {
-	Connected bool
-	RTT       time.Duration
-} {
-	pm.mutex.Lock()
-	peers := make(map[int]*Client, len(pm.monitors))
-	for siteID, client := range pm.monitors {
-		peers[siteID] = client
-	}
-	pm.mutex.Unlock()
-
-	results := make(map[int]struct {
-		Connected bool
-		RTT       time.Duration
-	})
-
-	for siteID, client := range peers {
-		ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
-		connected, rtt := client.TestConnection(ctx)
-		cancel()
-
-		results[siteID] = struct {
-			Connected bool
-			RTT       time.Duration
-		}{
-			Connected: connected,
-			RTT:       rtt,
-		}
-	}
-
-	return results
-}
+// // TestPeer tests connectivity to a specific peer
+// func (pm *PeerMonitor) TestPeer(siteID int) (bool, time.Duration, error) {
+// 	pm.mutex.Lock()
+// 	client, exists := pm.monitors[siteID]
+// 	pm.mutex.Unlock()
+//
+// 	if !exists {
+// 		return false, 0, fmt.Errorf("peer with siteID %d not found", siteID)
+// 	}
+//
+// 	ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
+// 	defer cancel()
+//
+// 	connected, rtt := client.TestPeerConnection(ctx)
+// 	return connected, rtt, nil
+// }
+
+// // TestAllPeers tests connectivity to all peers
+// func (pm *PeerMonitor) TestAllPeers() map[int]struct {
+// 	Connected bool
+// 	RTT       time.Duration
+// } {
+// 	pm.mutex.Lock()
+// 	peers := make(map[int]*Client, len(pm.monitors))
+// 	for siteID, client := range pm.monitors {
+// 		peers[siteID] = client
+// 	}
+// 	pm.mutex.Unlock()
+//
+// 	results := make(map[int]struct {
+// 		Connected bool
+// 		RTT       time.Duration
+// 	})
+//
+// 	for siteID, client := range peers {
+// 		ctx, cancel := context.WithTimeout(context.Background(), pm.timeout*time.Duration(pm.maxAttempts))
+// 		connected, rtt := client.TestPeerConnection(ctx)
+// 		cancel()
+//
+// 		results[siteID] = struct {
+// 			Connected bool
+// 			RTT       time.Duration
+// 		}{
+// 			Connected: connected,
+// 			RTT:       rtt,
+// 		}
+// 	}
+//
+// 	return results
+// }

 // initNetstack initializes the gvisor netstack
 func (pm *PeerMonitor) initNetstack() error {

View File

@@ -32,12 +32,15 @@ type Client struct {
 	monitorLock sync.Mutex
 	connLock    sync.Mutex // Protects connection operations
 	shutdownCh  chan struct{}
+	updateCh    chan struct{}

 	packetInterval time.Duration
 	timeout        time.Duration
 	maxAttempts    int
 	dialer         Dialer

 	// Exponential backoff fields
+	defaultMinInterval time.Duration // Default minimum interval (initial)
+	defaultMaxInterval time.Duration // Default maximum interval (cap for backoff)
 	minInterval        time.Duration // Minimum interval (initial)
 	maxInterval        time.Duration // Maximum interval (cap for backoff)
 	backoffMultiplier  float64       // Multiplier for each stable check

@@ -58,7 +61,10 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
 	return &Client{
 		serverAddr:         serverAddr,
 		shutdownCh:         make(chan struct{}),
+		updateCh:           make(chan struct{}, 1),
 		packetInterval:     2 * time.Second,
+		defaultMinInterval: 2 * time.Second,
+		defaultMaxInterval: 30 * time.Second,
 		minInterval:        2 * time.Second,
 		maxInterval:        30 * time.Second,
 		backoffMultiplier:  1.5,
@@ -70,29 +76,42 @@ func NewClient(serverAddr string, dialer Dialer) (*Client, error) {
 	}

 // SetPacketInterval changes how frequently packets are sent in monitor mode
-func (c *Client) SetPacketInterval(interval time.Duration) {
-	c.packetInterval = interval
-	c.minInterval = interval
-}
-
-// SetTimeout changes the timeout for waiting for responses
-func (c *Client) SetTimeout(timeout time.Duration) {
-	c.timeout = timeout
-}
-
-// SetMaxAttempts changes the maximum number of attempts for TestConnection
-func (c *Client) SetMaxAttempts(attempts int) {
-	c.maxAttempts = attempts
-}
-
-// SetMaxInterval sets the maximum backoff interval
-func (c *Client) SetMaxInterval(interval time.Duration) {
-	c.maxInterval = interval
-}
-
-// SetBackoffMultiplier sets the multiplier for exponential backoff
-func (c *Client) SetBackoffMultiplier(multiplier float64) {
-	c.backoffMultiplier = multiplier
-}
+func (c *Client) SetPacketInterval(minInterval, maxInterval time.Duration) {
+	c.monitorLock.Lock()
+	c.packetInterval = minInterval
+	c.minInterval = minInterval
+	c.maxInterval = maxInterval
+	updateCh := c.updateCh
+	monitorRunning := c.monitorRunning
+	c.monitorLock.Unlock()
+
+	// Signal the goroutine to apply the new interval if running
+	if monitorRunning && updateCh != nil {
+		select {
+		case updateCh <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+}
+
+func (c *Client) ResetPacketInterval() {
+	c.monitorLock.Lock()
+	c.packetInterval = c.defaultMinInterval
+	c.minInterval = c.defaultMinInterval
+	c.maxInterval = c.defaultMaxInterval
+	updateCh := c.updateCh
+	monitorRunning := c.monitorRunning
+	c.monitorLock.Unlock()
+
+	// Signal the goroutine to apply the new interval if running
+	if monitorRunning && updateCh != nil {
+		select {
+		case updateCh <- struct{}{}:
+		default:
+			// Channel full or closed, skip
+		}
+	}
+}

 // UpdateServerAddr updates the server address and resets the connection
@@ -146,9 +165,10 @@ func (c *Client) ensureConnection() error {
 	return nil
 }

-// TestConnection checks if the connection to the server is working
+// TestPeerConnection checks if the connection to the server is working
 // Returns true if connected, false otherwise
-func (c *Client) TestConnection(ctx context.Context) (bool, time.Duration) {
+func (c *Client) TestPeerConnection(ctx context.Context) (bool, time.Duration) {
+	logger.Debug("wgtester: testing connection to peer %s", c.serverAddr)
 	if err := c.ensureConnection(); err != nil {
 		logger.Warn("Failed to ensure connection: %v", err)
 		return false, 0

@@ -232,7 +252,7 @@
 func (c *Client) TestConnectionWithTimeout(timeout time.Duration) (bool, time.Duration) {
 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
 	defer cancel()
-	return c.TestConnection(ctx)
+	return c.TestPeerConnection(ctx)
 }

 // MonitorCallback is the function type for connection status change callbacks

@@ -269,9 +289,20 @@ func (c *Client) StartMonitor(callback MonitorCallback) error {
 		select {
 		case <-c.shutdownCh:
 			return
+		case <-c.updateCh:
+			// Interval settings changed, reset to minimum
+			c.monitorLock.Lock()
+			currentInterval = c.minInterval
+			c.monitorLock.Unlock()
+
+			// Reset backoff state
+			stableCount = 0
+
+			timer.Reset(currentInterval)
+			logger.Debug("Packet interval updated, reset to %v", currentInterval)
 		case <-timer.C:
 			ctx, cancel := context.WithTimeout(context.Background(), c.timeout)
-			connected, rtt := c.TestConnection(ctx)
+			connected, rtt := c.TestPeerConnection(ctx)
 			cancel()

 			statusChanged := connected != lastConnected
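
The new updateCh case fixes a latency problem: a monitor that has backed off to a long interval would otherwise not notice a SetPacketInterval or ResetPacketInterval call until its timer next fired, which in low power mode could be minutes away. A runnable sketch of this reset-on-signal backoff loop (simplified: the connectivity test is replaced with a print):

package main

import (
	"fmt"
	"time"
)

// backoffLoop sketches the timer pattern from StartMonitor above: the
// interval grows by a multiplier while things are stable, and an update
// channel snaps it back to the minimum when settings change.
func backoffLoop(min, max time.Duration, mult float64, update <-chan struct{}, stop <-chan struct{}) {
	interval := min
	timer := time.NewTimer(interval)
	defer timer.Stop()
	for {
		select {
		case <-stop:
			return
		case <-update:
			interval = min // settings changed: restart from the minimum
			timer.Reset(interval)
		case <-timer.C:
			fmt.Println("check at interval", interval)
			interval = time.Duration(float64(interval) * mult) // stable: back off
			if interval > max {
				interval = max
			}
			timer.Reset(interval)
		}
	}
}

func main() {
	update := make(chan struct{}, 1)
	stop := make(chan struct{})
	go backoffLoop(10*time.Millisecond, 80*time.Millisecond, 2, update, stop)
	time.Sleep(120 * time.Millisecond)
	update <- struct{}{} // e.g. ResetPacketInterval was called
	time.Sleep(60 * time.Millisecond)
	close(stop)
}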

View File

@@ -11,7 +11,7 @@ import (
 )

 // ConfigurePeer sets up or updates a peer within the WireGuard device
-func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error {
+func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error {
 	var endpoint string
 	if relay && siteConfig.RelayEndpoint != "" {
 		endpoint = formatEndpoint(siteConfig.RelayEndpoint)

@@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes
 	}

 	configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost))
-	configBuilder.WriteString("persistent_keepalive_interval=5\n")
+	configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive))

 	config := configBuilder.String()
 	logger.Debug("Configuring peer with config: %s", config)

@@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [
 	return nil
 }

+// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it
+func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error {
+	var configBuilder strings.Builder
+	configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey)))
+	configBuilder.WriteString("update_only=true\n")
+	configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval))
+
+	config := configBuilder.String()
+	logger.Debug("Updating persistent keepalive for peer with config: %s", config)
+
+	err := dev.IpcSet(config)
+	if err != nil {
+		return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err)
+	}
+
+	return nil
+}
+
 func formatEndpoint(endpoint string) string {
 	if strings.Contains(endpoint, ":") {
 		return endpoint
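
UpdatePersistentKeepalive drives wireguard-go's UAPI directly: update_only=true modifies the existing peer instead of creating one, and persistent_keepalive_interval=0 turns keepalives off entirely, which is how low power mode silences idle traffic. A small sketch of the config string it builds (the placeholder key is illustrative; the real code passes the result to dev.IpcSet):

package main

import (
	"fmt"
	"strings"
)

// buildKeepaliveUpdate builds the UAPI snippet used by
// UpdatePersistentKeepalive above.
func buildKeepaliveUpdate(publicKeyHex string, intervalSeconds int) string {
	var b strings.Builder
	b.WriteString(fmt.Sprintf("public_key=%s\n", publicKeyHex))
	b.WriteString("update_only=true\n")
	b.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", intervalSeconds))
	return b.String()
}

func main() {
	// UAPI wants the peer key as lowercase hex (32 bytes = 64 hex chars).
	key := strings.Repeat("ab", 32) // placeholder key for illustration
	fmt.Print(buildKeepaliveUpdate(key, 0))
}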

View File

@@ -77,6 +77,7 @@ type Client struct {
 	handlersMux       sync.RWMutex
 	reconnectInterval time.Duration
 	isConnected       bool
+	isDisconnected    bool // Flag to track if client is intentionally disconnected
 	reconnectMux      sync.RWMutex
 	pingInterval      time.Duration
 	pingTimeout       time.Duration

@@ -87,6 +88,10 @@ type Client struct {
 	clientType      string // Type of client (e.g., "newt", "olm")
 	tlsConfig       TLSConfig
 	configNeedsSave bool // Flag to track if config needs to be saved
+
+	token         string       // Cached authentication token
+	exitNodes     []ExitNode   // Cached exit nodes from token response
+	tokenMux      sync.RWMutex // Protects token and exitNodes
+	forceNewToken bool         // Flag to force fetching a new token on next connection
 }

 type ClientOption func(*Client)

@@ -173,6 +178,9 @@ func (c *Client) GetConfig() *Config {

 // Connect establishes the WebSocket connection
 func (c *Client) Connect() error {
+	if c.isDisconnected {
+		c.isDisconnected = false
+	}
 	go c.connectWithRetry()
 	return nil
 }
@@ -205,9 +213,25 @@ func (c *Client) Close() error {
 	return nil
 }

+// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
+func (c *Client) Disconnect() error {
+	c.isDisconnected = true
+	c.setConnected(false)
+
+	if c.conn != nil {
+		c.writeMux.Lock()
+		c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
+		c.writeMux.Unlock()
+		err := c.conn.Close()
+		c.conn = nil
+		return err
+	}
+	return nil
+}
+
 // SendMessage sends a message through the WebSocket connection
 func (c *Client) SendMessage(messageType string, data interface{}) error {
-	if c.conn == nil {
+	if c.isDisconnected || c.conn == nil {
 		return fmt.Errorf("not connected")
 	}

@@ -216,14 +240,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
 		Data: data,
 	}

-	logger.Debug("Sending message: %s, data: %+v", messageType, data)
+	logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)

 	c.writeMux.Lock()
 	defer c.writeMux.Unlock()
 	return c.conn.WriteJSON(msg)
 }

-func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
+func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
 	stopChan := make(chan struct{})
 	updateChan := make(chan interface{})
 	var dataMux sync.Mutex

@@ -231,30 +255,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
 	go func() {
 		count := 0
-		maxAttempts := 10

-		err := c.SendMessage(messageType, currentData) // Send immediately
-		if err != nil {
-			logger.Error("Failed to send initial message: %v", err)
-		}
-		count++
+		send := func() {
+			if c.isDisconnected || c.conn == nil {
+				return
+			}
+			err := c.SendMessage(messageType, currentData)
+			if err != nil {
+				logger.Error("websocket: Failed to send message: %v", err)
+			}
+			count++
+		}
+
+		send() // Send immediately

 		ticker := time.NewTicker(interval)
 		defer ticker.Stop()

 		for {
 			select {
 			case <-ticker.C:
-				if count >= maxAttempts {
-					logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
+				if maxAttempts != -1 && count >= maxAttempts {
+					logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
 					return
 				}
 				dataMux.Lock()
-				err = c.SendMessage(messageType, currentData)
+				send()
 				dataMux.Unlock()
-				if err != nil {
-					logger.Error("Failed to send message: %v", err)
-				}
-				count++
 			case newData := <-updateChan:
 				dataMux.Lock()
 				// Merge newData into currentData if both are maps

@@ -277,6 +303,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
 			case <-stopChan:
 				return
 			}
+
+			// Suspend sending if disconnected
+			for c.isDisconnected {
+				select {
+				case <-stopChan:
+					return
+				case <-time.After(500 * time.Millisecond):
+				}
+			}
 		}
 	}()

 	return func() {
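
SendMessageInterval now takes the attempt cap as a parameter instead of hard-coding 10, with -1 meaning no cap; that is what lets the 30-second olm/ping loop run indefinitely while registration and handshake retries stay bounded. A self-contained sketch of that control flow, with the send callback standing in for c.SendMessage:

package main

import (
	"fmt"
	"time"
)

// sendEvery sketches the retry loop from SendMessageInterval above:
// send immediately, then on each tick, until stop() is called or
// maxAttempts is reached; maxAttempts == -1 means no cap.
func sendEvery(send func() error, interval time.Duration, maxAttempts int) (stop func()) {
	stopChan := make(chan struct{})
	go func() {
		count := 0
		_ = send()
		count++
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if maxAttempts != -1 && count >= maxAttempts {
					return // give up after the cap
				}
				_ = send()
				count++
			case <-stopChan:
				return
			}
		}
	}()
	return func() { close(stopChan) }
}

func main() {
	stop := sendEvery(func() error {
		fmt.Println("ping", time.Now().Unix())
		return nil
	}, 20*time.Millisecond, -1)
	time.Sleep(100 * time.Millisecond)
	stop() // e.g. called once the server acknowledges
	time.Sleep(30 * time.Millisecond)
}
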
@@ -323,7 +357,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 			tlsConfig = &tls.Config{}
 		}
 		tlsConfig.InsecureSkipVerify = true
-		logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
 	}

 	tokenData := map[string]interface{}{

@@ -352,7 +386,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 	req.Header.Set("X-CSRF-Token", "x-csrf-protection")

 	// print out the request for debugging
-	logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
+	logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))

 	// Make the request
 	client := &http.Client{}

@@ -369,7 +403,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 	if resp.StatusCode != http.StatusOK {
 		body, _ := io.ReadAll(resp.Body)
-		logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
+		logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))

 		// Return AuthError for 401/403 status codes
 		if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {

@@ -385,7 +419,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 	var tokenResp TokenResponse
 	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
-		logger.Error("Failed to decode token response.")
+		logger.Error("websocket: Failed to decode token response.")
 		return "", nil, fmt.Errorf("failed to decode token response: %w", err)
 	}

@@ -397,7 +431,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
 		return "", nil, fmt.Errorf("received empty token from server")
 	}

-	logger.Debug("Received token: %s", tokenResp.Data.Token)
+	logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)

 	return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
 }

@@ -412,7 +446,7 @@ func (c *Client) connectWithRetry() {
 		if err != nil {
 			// Check if this is an auth error (401/403)
 			if authErr, ok := err.(*AuthError); ok {
-				logger.Error("Authentication failed: %v. Terminating tunnel and retrying...", authErr)
+				logger.Error("websocket: Authentication failed: %v. Terminating tunnel and retrying...", authErr)
 				// Trigger auth error callback if set (this should terminate the tunnel)
 				if c.onAuthError != nil {
 					c.onAuthError(authErr.StatusCode, authErr.Message)

@@ -422,7 +456,7 @@ func (c *Client) connectWithRetry() {
 				continue
 			}
 			// For other errors (5xx, network issues), continue retrying
-			logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
+			logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
 			time.Sleep(c.reconnectInterval)
 			continue
 		}
@@ -432,15 +466,25 @@
 	}
 }

 func (c *Client) establishConnection() error {
-	// Get token for authentication
-	token, exitNodes, err := c.getToken()
-	if err != nil {
-		return fmt.Errorf("failed to get token: %w", err)
-	}
-
-	if c.onTokenUpdate != nil {
-		c.onTokenUpdate(token, exitNodes)
-	}
+	// Get token for authentication - reuse the cached token unless forced to get a new one
+	c.tokenMux.Lock()
+	needNewToken := c.token == "" || c.forceNewToken
+	if needNewToken {
+		token, exitNodes, err := c.getToken()
+		if err != nil {
+			c.tokenMux.Unlock()
+			return fmt.Errorf("failed to get token: %w", err)
+		}
+		c.token = token
+		c.exitNodes = exitNodes
+		c.forceNewToken = false
+
+		if c.onTokenUpdate != nil {
+			c.onTokenUpdate(token, exitNodes)
+		}
+	}
+	token := c.token
+	c.tokenMux.Unlock()

 	// Parse the base URL to determine protocol and hostname
 	baseURL, err := url.Parse(c.baseURL)

@@ -475,7 +519,7 @@ func (c *Client) establishConnection() error {
 	// Use new TLS configuration method
 	if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
-		logger.Info("Setting up TLS configuration for WebSocket connection")
+		logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
 		tlsConfig, err := c.setupTLS()
 		if err != nil {
 			return fmt.Errorf("failed to setup TLS configuration: %w", err)

@@ -489,11 +533,23 @@ func (c *Client) establishConnection() error {
 			dialer.TLSClientConfig = &tls.Config{}
 		}
 		dialer.TLSClientConfig.InsecureSkipVerify = true
-		logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
+		logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
 	}

-	conn, _, err := dialer.Dial(u.String(), nil)
+	conn, resp, err := dialer.Dial(u.String(), nil)
 	if err != nil {
+		// Check if this is an unauthorized error (401)
+		if resp != nil && resp.StatusCode == http.StatusUnauthorized {
+			logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
+			// Force getting a new token on next reconnect attempt
+			c.tokenMux.Lock()
+			c.forceNewToken = true
+			c.tokenMux.Unlock()
+			return &AuthError{
+				StatusCode: http.StatusUnauthorized,
+				Message:    "WebSocket connection unauthorized",
+			}
+		}
 		return fmt.Errorf("failed to connect to WebSocket: %w", err)
 	}
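
Token handling in establishConnection now follows a reuse-until-rejected policy: the first dial fetches and caches a token, reconnects reuse it, and a 401 on the websocket handshake sets forceNewToken so the next attempt re-authenticates. A compact sketch of that policy in isolation (fetch stands in for getToken, invalidate for the 401 branch):

package main

import (
	"errors"
	"fmt"
	"sync"
)

// tokenCache keeps the last token and only refetches after the server
// rejects a connection, mirroring the caching logic above.
type tokenCache struct {
	mu       sync.Mutex
	token    string
	forceNew bool
	fetch    func() (string, error)
}

func (tc *tokenCache) get() (string, error) {
	tc.mu.Lock()
	defer tc.mu.Unlock()
	if tc.token == "" || tc.forceNew {
		t, err := tc.fetch()
		if err != nil {
			return "", err
		}
		tc.token = t
		tc.forceNew = false
	}
	return tc.token, nil
}

func (tc *tokenCache) invalidate() { // call on a 401 response
	tc.mu.Lock()
	tc.forceNew = true
	tc.mu.Unlock()
}

func main() {
	calls := 0
	tc := &tokenCache{fetch: func() (string, error) {
		calls++
		if calls > 2 {
			return "", errors.New("auth server unavailable")
		}
		return fmt.Sprintf("token-%d", calls), nil
	}}
	t1, _ := tc.get()
	t2, _ := tc.get() // cached, no second fetch
	tc.invalidate()   // pretend the websocket dial returned 401
	t3, _ := tc.get() // refetched
	fmt.Println(t1, t2, t3, "fetches:", calls)
}
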
@@ -507,7 +563,7 @@ func (c *Client) establishConnection() error {
 	if c.onConnect != nil {
 		if err := c.onConnect(); err != nil {
-			logger.Error("OnConnect callback failed: %v", err)
+			logger.Error("websocket: OnConnect callback failed: %v", err)
 		}
 	}

@@ -520,9 +576,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 	// Handle new separate certificate configuration
 	if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
-		logger.Info("Loading separate certificate files for mTLS")
-		logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
-		logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
+		logger.Info("websocket: Loading separate certificate files for mTLS")
+		logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
+		logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)

 		// Load client certificate and key
 		cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)

@@ -533,7 +589,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 		// Load CA certificates for remote validation if specified
 		if len(c.tlsConfig.CAFiles) > 0 {
-			logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
+			logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
 			caCertPool := x509.NewCertPool()
 			for _, caFile := range c.tlsConfig.CAFiles {
 				caCert, err := os.ReadFile(caFile)

@@ -559,13 +615,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
 	// Fallback to existing PKCS12 implementation for backward compatibility
 	if c.tlsConfig.PKCS12File != "" {
-		logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
 		return c.setupPKCS12TLS()
 	}

 	// Legacy fallback using config.TlsClientCert
 	if c.config.TlsClientCert != "" {
-		logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
+		logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
 		return loadClientCertificate(c.config.TlsClientCert)
 	}

@@ -587,7 +643,7 @@ func (c *Client) pingMonitor() {
 		case <-c.done:
 			return
 		case <-ticker.C:
-			if c.conn == nil {
+			if c.isDisconnected || c.conn == nil {
 				return
 			}
 			c.writeMux.Lock()

@@ -600,7 +656,7 @@ func (c *Client) pingMonitor() {
 				// Expected during shutdown
 				return
 			default:
-				logger.Error("Ping failed: %v", err)
+				logger.Error("websocket: Ping failed: %v", err)
 				c.reconnect()
 				return
 			}
@@ -633,18 +689,24 @@ func (c *Client) readPumpWithDisconnectDetection() {
 		var msg WSMessage
 		err := c.conn.ReadJSON(&msg)
 		if err != nil {
-			// Check if we're shutting down before logging error
+			// Check if we're shutting down or explicitly disconnected before logging error
 			select {
 			case <-c.done:
 				// Expected during shutdown, don't log as error
-				logger.Debug("WebSocket connection closed during shutdown")
+				logger.Debug("websocket: connection closed during shutdown")
 				return
 			default:
+				// Check if explicitly disconnected
+				if c.isDisconnected {
+					logger.Debug("websocket: connection closed: client was explicitly disconnected")
+					return
+				}
 				// Unexpected error during normal operation
 				if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
-					logger.Error("WebSocket read error: %v", err)
+					logger.Error("websocket: read error: %v", err)
 				} else {
-					logger.Debug("WebSocket connection closed: %v", err)
+					logger.Debug("websocket: connection closed: %v", err)
 				}
 				return // triggers reconnect via defer
 			}

@@ -666,6 +728,12 @@ func (c *Client) reconnect() {
 		c.conn = nil
 	}

+	// Don't reconnect if explicitly disconnected
+	if c.isDisconnected {
+		logger.Debug("websocket: Not reconnecting: client was explicitly disconnected")
+		return
+	}
+
 	// Only reconnect if we're not shutting down
 	select {
 	case <-c.done:

@@ -683,7 +751,7 @@ func (c *Client) setConnected(status bool) {

 // LoadClientCertificate Helper method to load client certificates (PKCS12 format)
 func loadClientCertificate(p12Path string) (*tls.Config, error) {
-	logger.Info("Loading tls-client-cert %s", p12Path)
+	logger.Info("websocket: Loading tls-client-cert %s", p12Path)

 	// Read the PKCS12 file
 	p12Data, err := os.ReadFile(p12Path)
 	if err != nil {