From 4f09d122bb53f3a32f73624edb2ca7d1c26b3175 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 14 Jan 2026 11:58:12 -0800 Subject: [PATCH] Refactor operation --- olm/connect.go | 2 +- olm/data.go | 4 +- olm/olm.go | 115 ++++++++++++++------------------------------ olm/ping.go | 56 --------------------- websocket/client.go | 58 ++++++++++++++++------ 5 files changed, 82 insertions(+), 153 deletions(-) delete mode 100644 olm/ping.go diff --git a/olm/connect.go b/olm/connect.go index 568c731..a610ea4 100644 --- a/olm/connect.go +++ b/olm/connect.go @@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) { MiddleDev: o.middleDev, LocalIP: interfaceIP, SharedBind: o.sharedBind, - WSClient: o.olmClient, + WSClient: o.websocket, APIServer: o.apiServer, }) diff --git a/olm/data.go b/olm/data.go index 9c8d33f..93e64d0 100644 --- a/olm/data.go +++ b/olm/data.go @@ -189,9 +189,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ "siteId": handshakeData.SiteId, - }, 1*time.Second) + }, 1*time.Second, 10) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } diff --git a/olm/olm.go b/olm/olm.go index 15e3a6a..63b53a7 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -41,7 +41,7 @@ type Olm struct { dnsProxy *dns.DNSProxy apiServer *api.API - olmClient *websocket.Client + websocket *websocket.Client holePunchManager *holepunch.Manager peerManager *peers.PeerManager // Power mode management @@ -57,10 +57,11 @@ type Olm struct { tunnelConfig TunnelConfig stopRegister func() - stopPeerSend func() updateRegister func(newData any) - stopPing chan struct{} + stopServerPing func() + + stopPeerSend func() } // initTunnelInfo creates the shared UDP socket and holepunch manager. @@ -270,9 +271,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { tunnelCtx, cancel := context.WithCancel(o.olmCtx) o.tunnelCancel = cancel - // Recreate channels for this tunnel session - o.stopPing = make(chan struct{}) - var ( id = config.ID secret = config.Secret @@ -328,6 +326,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) { o.apiServer.SetConnectionStatus(true) + // restart the ping if we need to + if o.stopServerPing == nil { + o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{ + "timestamp": time.Now().Unix(), + "userToken": olmClient.GetConfig().UserToken, + }, 30*time.Second, -1) // -1 means dont time out with the max attempts + } + if o.connected { logger.Debug("Already connected, skipping registration") return nil @@ -347,7 +353,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "olmAgent": o.olmConfig.Agent, "orgId": config.OrgID, "userToken": userToken, - }, 1*time.Second) + }, 1*time.Second, 10) // Invoke onRegistered callback if configured if o.olmConfig.OnRegistered != nil { @@ -355,8 +361,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } } - go o.keepSendingPing(olmClient) - return nil }) @@ -416,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { } defer func() { _ = olmClient.Close() }() - o.olmClient = olmClient + o.websocket = olmClient // Wait for context cancellation <-tunnelCtx.Done() @@ -435,9 +439,9 @@ func (o *Olm) Close() { o.holePunchManager = nil } - if o.stopPing != nil { - close(o.stopPing) - o.stopPing = nil + if o.stopServerPing != nil { + o.stopServerPing() + o.stopServerPing = nil } if o.stopRegister != nil { @@ -515,9 +519,9 @@ func (o *Olm) StopTunnel() error { } // Close the websocket connection - if o.olmClient != nil { - _ = o.olmClient.Close() - o.olmClient = nil + if o.websocket != nil { + _ = o.websocket.Close() + o.websocket = nil } o.Close() @@ -602,25 +606,13 @@ func (o *Olm) SetPowerMode(mode string) error { if mode == "low" { // Low Power Mode: Close websocket and reduce monitoring frequency - if o.olmClient != nil { + if o.websocket != nil { logger.Info("Closing websocket connection for low power mode") - if err := o.olmClient.Close(); err != nil { + if err := o.websocket.Close(); err != nil { logger.Error("Error closing websocket: %v", err) } } - if o.stopPing != nil { - select { - case <-o.stopPing: - default: - close(o.stopPing) - } - } - - if o.peerManager != nil { - o.peerManager.Stop() - } - if o.originalPeerInterval == 0 && o.peerManager != nil { peerMonitor := o.peerManager.GetPeerMonitor() if peerMonitor != nil { @@ -639,10 +631,6 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } - o.currentPowerMode = "low" logger.Info("Switched to low power mode") @@ -669,60 +657,19 @@ func (o *Olm) SetPowerMode(mode string) error { } } - if o.peerManager != nil { - o.peerManager.Start() - } + logger.Info("Reconnecting websocket for normal power mode") - if o.tunnelConfig.ID != "" && o.tunnelConfig.Secret != "" && o.tunnelConfig.Endpoint != "" { - logger.Info("Reconnecting websocket for normal power mode") - - if o.olmClient != nil { - o.olmClient.Close() - } - - o.stopPing = make(chan struct{}) - - var ( - id = o.tunnelConfig.ID - secret = o.tunnelConfig.Secret - userToken = o.tunnelConfig.UserToken - ) - - olm, err := websocket.NewClient( - id, - secret, - userToken, - o.tunnelConfig.OrgID, - o.tunnelConfig.Endpoint, - o.tunnelConfig.PingIntervalDuration, - o.tunnelConfig.PingTimeoutDuration, - ) - if err != nil { - logger.Error("Failed to create new websocket client: %v", err) - return fmt.Errorf("failed to create new websocket client: %w", err) - } - - o.olmClient = olm - - olm.OnConnect(func() error { - logger.Info("Websocket Reconnected") - o.apiServer.SetConnectionStatus(true) - go o.keepSendingPing(olm) - return nil - }) - - if err := olm.Connect(); err != nil { + if o.websocket != nil { + if err := o.websocket.Connect(); err != nil { logger.Error("Failed to reconnect websocket: %v", err) return fmt.Errorf("failed to reconnect websocket: %w", err) } - } else { - logger.Warn("Cannot reconnect websocket: tunnel config not available") } o.currentPowerMode = "normal" logger.Info("Switched to normal power mode") } - + return nil } @@ -749,6 +696,14 @@ func (o *Olm) AddDevice(fd uint32) error { o.middleDev.AddDevice(tdev) logger.Info("Added device from file descriptor %d", fd) - + return nil } + +func GetNetworkSettingsJSON() (string, error) { + return network.GetJSON() +} + +func GetNetworkSettingsIncrementor() int { + return network.GetIncrementor() +} diff --git a/olm/ping.go b/olm/ping.go deleted file mode 100644 index fd7706a..0000000 --- a/olm/ping.go +++ /dev/null @@ -1,56 +0,0 @@ -package olm - -import ( - "time" - - "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/network" - "github.com/fosrl/olm/websocket" -) - -func sendPing(olm *websocket.Client) error { - logger.Debug("Sending ping message") - err := olm.SendMessage("olm/ping", map[string]any{ - "timestamp": time.Now().Unix(), - "userToken": olm.GetConfig().UserToken, - }) - if err != nil { - logger.Error("Failed to send ping message: %v", err) - return err - } - logger.Debug("Sent ping message") - return nil -} - -func (o *Olm) keepSendingPing(olm *websocket.Client) { - // Send ping immediately on startup - if err := sendPing(olm); err != nil { - logger.Error("Failed to send initial ping: %v", err) - } else { - logger.Info("Sent initial ping message") - } - - // Set up ticker for one minute intervals - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-o.stopPing: - logger.Info("Stopping ping messages") - return - case <-ticker.C: - if err := sendPing(olm); err != nil { - logger.Error("Failed to send periodic ping: %v", err) - } - } - } -} - -func GetNetworkSettingsJSON() (string, error) { - return network.GetJSON() -} - -func GetNetworkSettingsIncrementor() int { - return network.GetIncrementor() -} diff --git a/websocket/client.go b/websocket/client.go index 1c5afaf..34eea35 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -77,6 +77,7 @@ type Client struct { handlersMux sync.RWMutex reconnectInterval time.Duration isConnected bool + isDisconnected bool // Flag to track if client is intentionally disconnected reconnectMux sync.RWMutex pingInterval time.Duration pingTimeout time.Duration @@ -173,6 +174,9 @@ func (c *Client) GetConfig() *Config { // Connect establishes the WebSocket connection func (c *Client) Connect() error { + if c.isDisconnected { + c.isDisconnected = false + } go c.connectWithRetry() return nil } @@ -205,9 +209,25 @@ func (c *Client) Close() error { return nil } +// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later. +func (c *Client) Disconnect() error { + c.isDisconnected = true + c.setConnected(false) + + if c.conn != nil { + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + err := c.conn.Close() + c.conn = nil + return err + } + return nil +} + // SendMessage sends a message through the WebSocket connection func (c *Client) SendMessage(messageType string, data interface{}) error { - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return fmt.Errorf("not connected") } @@ -223,7 +243,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) updateChan := make(chan interface{}) var dataMux sync.Mutex @@ -231,30 +251,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter go func() { count := 0 - maxAttempts := 10 - err := c.SendMessage(messageType, currentData) // Send immediately - if err != nil { - logger.Error("Failed to send initial message: %v", err) + send := func() { + if c.isDisconnected || c.conn == nil { + return + } + err := c.SendMessage(messageType, currentData) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + count++ } - count++ + + send() // Send immediately ticker := time.NewTicker(interval) defer ticker.Stop() for { select { case <-ticker.C: - if count >= maxAttempts { + if maxAttempts != -1 && count >= maxAttempts { logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } dataMux.Lock() - err = c.SendMessage(messageType, currentData) + send() dataMux.Unlock() - if err != nil { - logger.Error("Failed to send message: %v", err) - } - count++ case newData := <-updateChan: dataMux.Lock() // Merge newData into currentData if both are maps @@ -277,6 +299,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter case <-stopChan: return } + // Suspend sending if disconnected + for c.isDisconnected { + select { + case <-stopChan: + return + case <-time.After(500 * time.Millisecond): + } + } } }() return func() { @@ -587,7 +617,7 @@ func (c *Client) pingMonitor() { case <-c.done: return case <-ticker.C: - if c.conn == nil { + if c.isDisconnected || c.conn == nil { return } c.writeMux.Lock()