mirror of
https://github.com/fosrl/olm.git
synced 2026-02-24 05:46:46 +00:00
@@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
||||
MiddleDev: o.middleDev,
|
||||
LocalIP: interfaceIP,
|
||||
SharedBind: o.sharedBind,
|
||||
WSClient: o.olmClient,
|
||||
WSClient: o.websocket,
|
||||
APIServer: o.apiServer,
|
||||
})
|
||||
|
||||
|
||||
@@ -186,12 +186,12 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
}, 1*time.Second)
|
||||
}, 1*time.Second, 10)
|
||||
|
||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||
}
|
||||
|
||||
199
olm/olm.go
199
olm/olm.go
@@ -42,9 +42,14 @@ type Olm struct {
|
||||
|
||||
dnsProxy *dns.DNSProxy
|
||||
apiServer *api.API
|
||||
olmClient *websocket.Client
|
||||
websocket *websocket.Client
|
||||
holePunchManager *holepunch.Manager
|
||||
peerManager *peers.PeerManager
|
||||
// Power mode management
|
||||
currentPowerMode string
|
||||
powerModeMu sync.Mutex
|
||||
wakeUpTimer *time.Timer
|
||||
wakeUpDebounce time.Duration
|
||||
|
||||
olmCtx context.Context
|
||||
tunnelCancel context.CancelFunc
|
||||
@@ -58,10 +63,11 @@ type Olm struct {
|
||||
metaMu sync.Mutex
|
||||
|
||||
stopRegister func()
|
||||
stopPeerSend func()
|
||||
updateRegister func(newData any)
|
||||
|
||||
stopPing chan struct{}
|
||||
stopServerPing func()
|
||||
|
||||
stopPeerSend func()
|
||||
}
|
||||
|
||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||
@@ -134,6 +140,10 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||
logFile = file
|
||||
}
|
||||
|
||||
if config.WakeUpDebounce == 0 {
|
||||
config.WakeUpDebounce = 3 * time.Second
|
||||
}
|
||||
|
||||
logger.Debug("Checking permissions for native interface")
|
||||
err := permissions.CheckNativeInterfacePermissions()
|
||||
if err != nil {
|
||||
@@ -285,9 +295,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
||||
o.tunnelCancel = cancel
|
||||
|
||||
// Recreate channels for this tunnel session
|
||||
o.stopPing = make(chan struct{})
|
||||
|
||||
var (
|
||||
id = config.ID
|
||||
secret = config.Secret
|
||||
@@ -343,6 +350,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
|
||||
o.apiServer.SetConnectionStatus(true)
|
||||
|
||||
// restart the ping if we need to
|
||||
if o.stopServerPing == nil {
|
||||
o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": olmClient.GetConfig().UserToken,
|
||||
}, 30*time.Second, -1) // -1 means dont time out with the max attempts
|
||||
}
|
||||
|
||||
if o.connected {
|
||||
logger.Debug("Already connected, skipping registration")
|
||||
return nil
|
||||
@@ -364,7 +379,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
"userToken": userToken,
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
}, 1*time.Second)
|
||||
}, 1*time.Second, 10)
|
||||
|
||||
// Invoke onRegistered callback if configured
|
||||
if o.olmConfig.OnRegistered != nil {
|
||||
@@ -372,8 +387,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
}
|
||||
|
||||
go o.keepSendingPing(olmClient)
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
@@ -446,7 +459,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
||||
}
|
||||
defer func() { _ = olmClient.Close() }()
|
||||
|
||||
o.olmClient = olmClient
|
||||
o.websocket = olmClient
|
||||
|
||||
// Wait for context cancellation
|
||||
<-tunnelCtx.Done()
|
||||
@@ -465,9 +478,9 @@ func (o *Olm) Close() {
|
||||
o.holePunchManager = nil
|
||||
}
|
||||
|
||||
if o.stopPing != nil {
|
||||
close(o.stopPing)
|
||||
o.stopPing = nil
|
||||
if o.stopServerPing != nil {
|
||||
o.stopServerPing()
|
||||
o.stopServerPing = nil
|
||||
}
|
||||
|
||||
if o.stopRegister != nil {
|
||||
@@ -545,9 +558,9 @@ func (o *Olm) StopTunnel() error {
|
||||
}
|
||||
|
||||
// Close the websocket connection
|
||||
if o.olmClient != nil {
|
||||
_ = o.olmClient.Close()
|
||||
o.olmClient = nil
|
||||
if o.websocket != nil {
|
||||
_ = o.websocket.Close()
|
||||
o.websocket = nil
|
||||
}
|
||||
|
||||
o.Close()
|
||||
@@ -625,3 +638,157 @@ func (o *Olm) SetPostures(data map[string]any) {
|
||||
|
||||
o.postures = data
|
||||
}
|
||||
|
||||
// SetPowerMode switches between normal and low power modes
|
||||
// In low power mode: websocket is closed (stopping pings) and monitoring intervals are set to 10 minutes
|
||||
// In normal power mode: websocket is reconnected (restarting pings) and monitoring intervals are restored
|
||||
// Wake-up has a 3-second debounce to prevent rapid flip-flopping; sleep is immediate
|
||||
func (o *Olm) SetPowerMode(mode string) error {
|
||||
// Validate mode
|
||||
if mode != "normal" && mode != "low" {
|
||||
return fmt.Errorf("invalid power mode: %s (must be 'normal' or 'low')", mode)
|
||||
}
|
||||
|
||||
o.powerModeMu.Lock()
|
||||
defer o.powerModeMu.Unlock()
|
||||
|
||||
// If already in the requested mode, return early
|
||||
if o.currentPowerMode == mode {
|
||||
// Cancel any pending wake-up timer if we're already in normal mode
|
||||
if mode == "normal" && o.wakeUpTimer != nil {
|
||||
o.wakeUpTimer.Stop()
|
||||
o.wakeUpTimer = nil
|
||||
}
|
||||
logger.Debug("Already in %s power mode", mode)
|
||||
return nil
|
||||
}
|
||||
|
||||
if mode == "low" {
|
||||
// Low Power Mode: Cancel any pending wake-up and immediately go to sleep
|
||||
|
||||
// Cancel pending wake-up timer if any
|
||||
if o.wakeUpTimer != nil {
|
||||
logger.Debug("Cancelling pending wake-up timer")
|
||||
o.wakeUpTimer.Stop()
|
||||
o.wakeUpTimer = nil
|
||||
}
|
||||
|
||||
logger.Info("Switching to low power mode")
|
||||
|
||||
if o.websocket != nil {
|
||||
logger.Info("Disconnecting websocket for low power mode")
|
||||
if err := o.websocket.Disconnect(); err != nil {
|
||||
logger.Error("Error disconnecting websocket: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
lowPowerInterval := 10 * time.Minute
|
||||
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.SetPeerInterval(lowPowerInterval, lowPowerInterval)
|
||||
peerMonitor.SetPeerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
logger.Info("Set monitoring intervals to 10 minutes for low power mode")
|
||||
}
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(0) // disable
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.SetServerHolepunchInterval(lowPowerInterval, lowPowerInterval)
|
||||
}
|
||||
|
||||
o.currentPowerMode = "low"
|
||||
logger.Info("Switched to low power mode")
|
||||
|
||||
} else {
|
||||
// Normal Power Mode: Start debounce timer before actually waking up
|
||||
|
||||
// If there's already a pending wake-up timer, don't start another
|
||||
if o.wakeUpTimer != nil {
|
||||
logger.Debug("Wake-up already pending, ignoring duplicate request")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Info("Wake-up requested, starting %v debounce timer", o.wakeUpDebounce)
|
||||
|
||||
o.wakeUpTimer = time.AfterFunc(o.wakeUpDebounce, func() {
|
||||
o.powerModeMu.Lock()
|
||||
defer o.powerModeMu.Unlock()
|
||||
|
||||
// Clear the timer reference
|
||||
o.wakeUpTimer = nil
|
||||
|
||||
// Double-check we're still in low power mode (could have changed)
|
||||
if o.currentPowerMode == "normal" {
|
||||
logger.Debug("Already in normal mode after debounce, skipping wake-up")
|
||||
return
|
||||
}
|
||||
|
||||
logger.Info("Debounce complete, switching to normal power mode")
|
||||
|
||||
logger.Info("Reconnecting websocket for normal power mode")
|
||||
if o.websocket != nil {
|
||||
if err := o.websocket.Connect(); err != nil {
|
||||
logger.Error("Failed to reconnect websocket: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Restore intervals and reconnect websocket
|
||||
if o.peerManager != nil {
|
||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.ResetPeerHolepunchInterval()
|
||||
peerMonitor.ResetPeerInterval()
|
||||
}
|
||||
|
||||
o.peerManager.UpdateAllPeersPersistentKeepalive(5)
|
||||
}
|
||||
|
||||
if o.holePunchManager != nil {
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
o.currentPowerMode = "normal"
|
||||
logger.Info("Switched to normal power mode")
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) AddDevice(fd uint32) error {
|
||||
if o.middleDev == nil {
|
||||
return fmt.Errorf("middle device is not initialized")
|
||||
}
|
||||
|
||||
if o.tunnelConfig.MTU == 0 {
|
||||
return fmt.Errorf("tunnel MTU is not set")
|
||||
}
|
||||
|
||||
tdev, err := olmDevice.CreateTUNFromFD(fd, o.tunnelConfig.MTU)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create TUN device from fd: %v", err)
|
||||
}
|
||||
|
||||
// Update interface name if available
|
||||
if realInterfaceName, err2 := tdev.Name(); err2 == nil {
|
||||
o.tunnelConfig.InterfaceName = realInterfaceName
|
||||
}
|
||||
|
||||
// Replace the existing TUN device in the middle device with the new one
|
||||
o.middleDev.AddDevice(tdev)
|
||||
|
||||
logger.Info("Added device from file descriptor %d", fd)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetNetworkSettingsJSON() (string, error) {
|
||||
return network.GetJSON()
|
||||
}
|
||||
|
||||
func GetNetworkSettingsIncrementor() int {
|
||||
return network.GetIncrementor()
|
||||
}
|
||||
|
||||
@@ -123,7 +123,7 @@ func (o *Olm) handleWgPeerUpdate(msg websocket.WSMessage) {
|
||||
if updateData.Endpoint != "" && updateData.Endpoint != existingPeer.Endpoint {
|
||||
logger.Info("Endpoint changed for site %d, triggering holepunch to refresh NAT mappings", updateData.SiteId)
|
||||
_ = o.holePunchManager.TriggerHolePunch()
|
||||
o.holePunchManager.ResetInterval()
|
||||
o.holePunchManager.ResetServerHolepunchInterval()
|
||||
}
|
||||
|
||||
logger.Info("Successfully updated peer for site %d", updateData.SiteId)
|
||||
|
||||
57
olm/ping.go
57
olm/ping.go
@@ -1,57 +0,0 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/fosrl/newt/network"
|
||||
"github.com/fosrl/olm/websocket"
|
||||
)
|
||||
|
||||
func (o *Olm) sendPing(olm *websocket.Client) error {
|
||||
err := olm.SendMessage("olm/ping", map[string]any{
|
||||
"timestamp": time.Now().Unix(),
|
||||
"userToken": olm.GetConfig().UserToken,
|
||||
"fingerprint": o.fingerprint,
|
||||
"postures": o.postures,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping message: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Debug("Sent ping message")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (o *Olm) keepSendingPing(olm *websocket.Client) {
|
||||
// Send ping immediately on startup
|
||||
if err := o.sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send initial ping: %v", err)
|
||||
} else {
|
||||
logger.Info("Sent initial ping message")
|
||||
}
|
||||
|
||||
// Set up ticker for one minute intervals
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-o.stopPing:
|
||||
logger.Info("Stopping ping messages")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := o.sendPing(olm); err != nil {
|
||||
logger.Error("Failed to send periodic ping: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetNetworkSettingsJSON() (string, error) {
|
||||
return network.GetJSON()
|
||||
}
|
||||
|
||||
func GetNetworkSettingsIncrementor() int {
|
||||
return network.GetIncrementor()
|
||||
}
|
||||
@@ -23,6 +23,8 @@ type OlmConfig struct {
|
||||
SocketPath string
|
||||
Version string
|
||||
Agent string
|
||||
|
||||
WakeUpDebounce time.Duration
|
||||
|
||||
// Debugging
|
||||
PprofAddr string // Address to serve pprof on (e.g., "localhost:6060")
|
||||
|
||||
Reference in New Issue
Block a user