diff --git a/olm/olm.go b/olm/olm.go index 63b53a7..6a0a26f 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -7,6 +7,7 @@ import ( "net/http" _ "net/http/pprof" "os" + "sync" "time" "github.com/fosrl/newt/bind" @@ -46,9 +47,9 @@ type Olm struct { peerManager *peers.PeerManager // Power mode management currentPowerMode string - originalPeerInterval time.Duration - originalHolepunchMinInterval time.Duration - originalHolepunchMaxInterval time.Duration + powerModeMu sync.Mutex + wakeUpTimer *time.Timer + wakeUpDebounce time.Duration olmCtx context.Context tunnelCancel context.CancelFunc @@ -133,6 +134,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.SetOutput(file) logFile = file } + + if config.WakeUpDebounce == 0 { + config.WakeUpDebounce = 3 * time.Second + } logger.Debug("Checking permissions for native interface") err := permissions.CheckNativeInterfacePermissions() @@ -589,22 +594,38 @@ func (o *Olm) SwitchOrg(orgID string) error { // SetPowerMode switches between normal and low power modes // In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes // In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored +// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate func (o *Olm) SetPowerMode(mode string) error { // Validate mode if mode != "normal" && mode != "low" { return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode) } + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + // If already in the requested mode, return early if o.currentPowerMode == mode { + // Cancel any pending wake-up timer if we're already in normal mode + if mode == "normal" && o.wakeUpTimer != nil { + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } logger.Debug("Already in %s power mode", mode) return nil } - logger.Info("Switching to %s power mode", mode) - if mode == "low" { - // Low Power Mode: Close websocket and reduce monitoring frequency + // Low Power Mode: Cancel any pending wake-up and immediately go to sleep + + // Cancel pending wake-up timer if any + if o.wakeUpTimer != nil { + logger.Debug("Cancelling pending wake-up timer") + o.wakeUpTimer.Stop() + o.wakeUpTimer = nil + } + + logger.Info("Switching to low power mode") if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") @@ -613,14 +634,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.originalPeerInterval == 0 && o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - o.originalPeerInterval = 2 * time.Second - o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval = peerMonitor.GetHolepunchIntervals() - } - } - if o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -629,45 +642,61 @@ func (o *Olm) SetPowerMode(mode string) error { peerMonitor.SetHolepunchInterval(lowPowerInterval, lowPowerInterval) logger.Info("Set monitoring intervals to 10 minutes for low power mode") } + o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable } o.currentPowerMode = "low" logger.Info("Switched to low power mode") } else { - // Normal Power Mode: Restore intervals and reconnect websocket + // Normal Power Mode: Start debounce timer before actually waking up - if o.peerManager != nil { - peerMonitor := o.peerManager.GetPeerMonitor() - if peerMonitor != nil { - if o.originalPeerInterval == 0 { - o.originalPeerInterval = 2 * time.Second - } - peerMonitor.SetInterval(o.originalPeerInterval) - - if o.originalHolepunchMinInterval == 0 { - o.originalHolepunchMinInterval = 2 * time.Second - } - if o.originalHolepunchMaxInterval == 0 { - o.originalHolepunchMaxInterval = 30 * time.Second - } - peerMonitor.SetHolepunchInterval(o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - logger.Info("Restored monitoring intervals to normal (peer: %v, holepunch: %v-%v)", - o.originalPeerInterval, o.originalHolepunchMinInterval, o.originalHolepunchMaxInterval) - } + // If there's already a pending wake-up timer, don't start another + if o.wakeUpTimer != nil { + logger.Debug("Wake-up already pending, ignoring duplicate request") + return nil } - logger.Info("Reconnecting websocket for normal power mode") + logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce) - if o.websocket != nil { - if err := o.websocket.Connect(); err != nil { - logger.Error("Failed to reconnect websocket: %v", err) - return fmt.Errorf("failed to reconnect websocket: %w", err) + o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() { + o.powerModeMu.Lock() + defer o.powerModeMu.Unlock() + + // Clear the timer reference + o.wakeUpTimer = nil + + // Double-check we're still in low power mode (could have changed) + if o.currentPowerMode == "normal" { + logger.Debug("Already in normal mode after debounce, skipping wake-up") + return } - } - o.currentPowerMode = "normal" - logger.Info("Switched to normal power mode") + logger.Info("Debounce complete, switching to normal power mode") + + // Restore intervals and reconnect websocket + if o.peerManager != nil { + peerMonitor := o.peerManager.GetPeerMonitor() + if peerMonitor != nil { + peerMonitor.ResetHolepunchInterval() + peerMonitor.ResetInterval() + } + + o.peerManager.UpdateAllPeersPersistentKeepalive(5) + } + + logger.Info("Reconnecting websocket for normal power mode") + + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { + logger.Error("Failed to reconnect websocket: %v", err) + return + } + } + + o.currentPowerMode = "normal" + logger.Info("Switched to normal power mode") + }) } return nil diff --git a/olm/types.go b/olm/types.go index 77c0b5f..397eab9 100644 --- a/olm/types.go +++ b/olm/types.go @@ -23,6 +23,8 @@ type OlmConfig struct { SocketPath string Version string Agent string + + WakeUpDebounce time.Duration // Debugging PprofAddr string // Address to serve pprof on (e.g., "localhost:6060") diff --git a/peers/manager.go b/peers/manager.go index 56f3707..0566775 100644 --- a/peers/manager.go +++ b/peers/manager.go @@ -50,6 +50,8 @@ type PeerManager struct { // key is the CIDR string, value is a set of siteIds that want this IP allowedIPClaims map[string]map[int]bool APIServer *api.API + + PersistentKeepalive int } // NewPeerManager creates a new PeerManager with an internal PeerMonitor @@ -127,7 +129,7 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -166,6 +168,29 @@ func (pm *PeerManager) AddPeer(siteConfig SiteConfig) error { return nil } +// UpdateAllPeersPersistentKeepalive updates the persistent keepalive interval for all peers at once +// without recreating them. Returns a map of siteId to error for any peers that failed to update. +func (pm *PeerManager) UpdateAllPeersPersistentKeepalive(interval int) map[int]error { + pm.mu.RLock() + defer pm.mu.RUnlock() + + pm.PersistentKeepalive = interval + + errors := make(map[int]error) + + for siteId, peer := range pm.peers { + err := UpdatePersistentKeepalive(pm.device, peer.PublicKey, interval) + if err != nil { + errors[siteId] = err + } + } + + if len(errors) == 0 { + return nil + } + return errors +} + func (pm *PeerManager) RemovePeer(siteId int) error { pm.mu.Lock() defer pm.mu.Unlock() @@ -245,7 +270,7 @@ func (pm *PeerManager) RemovePeer(siteId int) error { ownedIPs := pm.getOwnedAllowedIPs(promotedPeerId) wgConfig := promotedPeer wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } @@ -321,7 +346,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { wgConfig := siteConfig wgConfig.AllowedIps = ownedIPs - if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId)); err != nil { + if err := ConfigurePeer(pm.device, wgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(siteConfig.SiteId), pm.PersistentKeepalive); err != nil { return err } @@ -331,7 +356,7 @@ func (pm *PeerManager) UpdatePeer(siteConfig SiteConfig) error { promotedOwnedIPs := pm.getOwnedAllowedIPs(promotedPeerId) promotedWgConfig := promotedPeer promotedWgConfig.AllowedIps = promotedOwnedIPs - if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId)); err != nil { + if err := ConfigurePeer(pm.device, promotedWgConfig, pm.privateKey, pm.peerMonitor.IsPeerRelayed(promotedPeerId), pm.PersistentKeepalive); err != nil { logger.Error("Failed to update promoted peer %d: %v", promotedPeerId, err) } } diff --git a/peers/monitor/monitor.go b/peers/monitor/monitor.go index 2bb0c80..3ac4b54 100644 --- a/peers/monitor/monitor.go +++ b/peers/monitor/monitor.go @@ -28,13 +28,14 @@ import ( // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*Client - mutex sync.Mutex - running bool - interval time.Duration - timeout time.Duration - maxAttempts int - wsClient *websocket.Client + monitors map[int]*Client + mutex sync.Mutex + running bool + defaultInterval time.Duration + interval time.Duration + timeout time.Duration + maxAttempts int + wsClient *websocket.Client // Netstack fields middleDev *middleDevice.MiddleDevice @@ -50,7 +51,6 @@ type PeerMonitor struct { // Holepunch testing fields sharedBind *bind.SharedBind holepunchTester *holepunch.HolepunchTester - holepunchInterval time.Duration holepunchTimeout time.Duration holepunchEndpoints map[int]string // siteID -> endpoint for holepunch testing holepunchStatus map[int]bool // siteID -> connected status @@ -62,11 +62,13 @@ type PeerMonitor struct { holepunchFailures map[int]int // siteID -> consecutive failure count // Exponential backoff fields for holepunch monitor - holepunchMinInterval time.Duration // Minimum interval (initial) - holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) - holepunchBackoffMultiplier float64 // Multiplier for each stable check - holepunchStableCount map[int]int // siteID -> consecutive stable status count - holepunchCurrentInterval time.Duration // Current interval with backoff applied + defaultHolepunchMinInterval time.Duration // Minimum interval (initial) + defaultHolepunchMaxInterval time.Duration + holepunchMinInterval time.Duration // Minimum interval (initial) + holepunchMaxInterval time.Duration // Maximum interval (cap for backoff) + holepunchBackoffMultiplier float64 // Multiplier for each stable check + holepunchStableCount map[int]int // siteID -> consecutive stable status count + holepunchCurrentInterval time.Duration // Current interval with backoff applied // Rapid initial test fields rapidTestInterval time.Duration // interval between rapid test attempts @@ -85,6 +87,7 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe ctx, cancel := context.WithCancel(context.Background()) pm := &PeerMonitor{ monitors: make(map[int]*Client), + defaultInterval: 2 * time.Second, interval: 2 * time.Second, // Default check interval (faster) timeout: 3 * time.Second, maxAttempts: 3, @@ -95,7 +98,6 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe nsCtx: ctx, nsCancel: cancel, sharedBind: sharedBind, - holepunchInterval: 2 * time.Second, // Check holepunch every 2 seconds holepunchTimeout: 2 * time.Second, // Faster timeout holepunchEndpoints: make(map[int]string), holepunchStatus: make(map[int]bool), @@ -109,11 +111,13 @@ func NewPeerMonitor(wsClient *websocket.Client, middleDev *middleDevice.MiddleDe apiServer: apiServer, wgConnectionStatus: make(map[int]bool), // Exponential backoff settings for holepunch monitor - holepunchMinInterval: 2 * time.Second, - holepunchMaxInterval: 30 * time.Second, - holepunchBackoffMultiplier: 1.5, - holepunchStableCount: make(map[int]int), - holepunchCurrentInterval: 2 * time.Second, + defaultHolepunchMinInterval: 2 * time.Second, + defaultHolepunchMaxInterval: 30 * time.Second, + holepunchMinInterval: 2 * time.Second, + holepunchMaxInterval: 30 * time.Second, + holepunchBackoffMultiplier: 1.5, + holepunchStableCount: make(map[int]int), + holepunchCurrentInterval: 2 * time.Second, } if err := pm.initNetstack(); err != nil { @@ -141,6 +145,18 @@ func (pm *PeerMonitor) SetInterval(interval time.Duration) { } } +func (pm *PeerMonitor) ResetInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.interval = pm.defaultInterval + + // Update interval for all existing monitors + for _, client := range pm.monitors { + client.SetPacketInterval(pm.defaultInterval) + } +} + // SetTimeout changes the timeout for waiting for responses func (pm *PeerMonitor) SetTimeout(timeout time.Duration) { pm.mutex.Lock() @@ -186,6 +202,15 @@ func (pm *PeerMonitor) GetHolepunchIntervals() (minInterval, maxInterval time.Du return pm.holepunchMinInterval, pm.holepunchMaxInterval } +func (pm *PeerMonitor) ResetHolepunchInterval() { + pm.mutex.Lock() + defer pm.mutex.Unlock() + + pm.holepunchMinInterval = pm.defaultHolepunchMinInterval + pm.holepunchMaxInterval = pm.defaultHolepunchMaxInterval + pm.holepunchCurrentInterval = pm.defaultHolepunchMinInterval +} + // AddPeer adds a new peer to monitor func (pm *PeerMonitor) AddPeer(siteID int, endpoint string, holepunchEndpoint string) error { pm.mutex.Lock() diff --git a/peers/peer.go b/peers/peer.go index 9370b9d..8211fa4 100644 --- a/peers/peer.go +++ b/peers/peer.go @@ -11,7 +11,7 @@ import ( ) // ConfigurePeer sets up or updates a peer within the WireGuard device -func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool) error { +func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes.Key, relay bool, persistentKeepalive int) error { var endpoint string if relay && siteConfig.RelayEndpoint != "" { endpoint = formatEndpoint(siteConfig.RelayEndpoint) @@ -61,7 +61,7 @@ func ConfigurePeer(dev *device.Device, siteConfig SiteConfig, privateKey wgtypes } configBuilder.WriteString(fmt.Sprintf("endpoint=%s\n", siteHost)) - configBuilder.WriteString("persistent_keepalive_interval=5\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", persistentKeepalive)) config := configBuilder.String() logger.Debug("Configuring peer with config: %s", config) @@ -134,6 +134,24 @@ func RemoveAllowedIP(dev *device.Device, publicKey string, remainingAllowedIPs [ return nil } +// UpdatePersistentKeepalive updates the persistent keepalive interval for a peer without recreating it +func UpdatePersistentKeepalive(dev *device.Device, publicKey string, interval int) error { + var configBuilder strings.Builder + configBuilder.WriteString(fmt.Sprintf("public_key=%s\n", util.FixKey(publicKey))) + configBuilder.WriteString("update_only=true\n") + configBuilder.WriteString(fmt.Sprintf("persistent_keepalive_interval=%d\n", interval)) + + config := configBuilder.String() + logger.Debug("Updating persistent keepalive for peer with config: %s", config) + + err := dev.IpcSet(config) + if err != nil { + return fmt.Errorf("failed to update persistent keepalive for WireGuard peer: %v", err) + } + + return nil +} + func formatEndpoint(endpoint string) string { if strings.Contains(endpoint, ":") { return endpoint