Refactor operation

Former-commit-id: 4f09d122bb
This commit is contained in:
Owen
2026-01-14 11:58:12 -08:00
parent 0e8315b149
commit c86df2c041
5 changed files with 82 additions and 153 deletions

View File

@@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
MiddleDev: o.middleDev,
LocalIP: interfaceIP,
SharedBind: o.sharedBind,
WSClient: o.olmClient,
WSClient: o.websocket,
APIServer: o.apiServer,
})

View File

@@ -189,9 +189,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry
o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
}, 1*time.Second)
}, 1*time.Second, 10)
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
}

View File

@@ -41,7 +41,7 @@ type Olm struct {
dnsProxy *dns.DNSProxy
apiServer *api.API
olmClient *websocket.Client
websocket *websocket.Client
holePunchManager *holepunch.Manager
peerManager *peers.PeerManager
// Power mode management
@@ -57,10 +57,11 @@ type Olm struct {
tunnelConfig TunnelConfig
stopRegister func()
stopPeerSend func()
updateRegister func(newData any)
stopPing chan struct{}
stopServerPing func()
stopPeerSend func()
}
// initTunnelInfo creates the shared UDP socket and holepunch manager.
@@ -270,9 +271,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
o.tunnelCancel = cancel
// Recreate channels for this tunnel session
o.stopPing = make(chan struct{})
var (
id = config.ID
secret = config.Secret
@@ -328,6 +326,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
o.apiServer.SetConnectionStatus(true)
// restart the ping if we need to
if o.stopServerPing == nil {
o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
"timestamp": time.Now().Unix(),
"userToken": olmClient.GetConfig().UserToken,
}, 30*time.Second, -1) // -1 means dont time out with the max attempts
}
if o.connected {
logger.Debug("Already connected, skipping registration")
return nil
@@ -347,7 +353,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"olmAgent": o.olmConfig.Agent,
"orgId": config.OrgID,
"userToken": userToken,
}, 1*time.Second)
}, 1*time.Second, 10)
// Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil {
@@ -355,8 +361,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}
}
go o.keepSendingPing(olmClient)
return nil
})
@@ -416,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
}
defer func() { _ = olmClient.Close() }()
o.olmClient = olmClient
o.websocket = olmClient
// Wait for context cancellation
<-tunnelCtx.Done()
@@ -435,9 +439,9 @@ func (o *Olm) Close() {
o.holePunchManager = nil
}
if o.stopPing != nil {
close(o.stopPing)
o.stopPing = nil
if o.stopServerPing != nil {
o.stopServerPing()
o.stopServerPing = nil
}
if o.stopRegister != nil {
@@ -515,9 +519,9 @@ func (o *Olm) StopTunnel() error {
}
// Close the websocket connection
if o.olmClient != nil {
_ = o.olmClient.Close()
o.olmClient = nil
if o.websocket != nil {
_ = o.websocket.Close()
o.websocket = nil
}
o.Close()
@@ -602,25 +606,13 @@ func (o *Olm) SetPowerMode(mode string) error {
if mode == "low" {
// Low Power Mode: Close websocket and reduce monitoring frequency
if o.olmClient != nil {
if o.websocket != nil {
logger.Info("Closing websocket connection for low power mode")
if err := o.olmClient.Close(); err != nil {
if err := o.websocket.Close(); err != nil {
logger.Error("Error closing websocket: %v", err)
}
}
if o.stopPing != nil {
select {
case <-o.stopPing:
default:
close(o.stopPing)
}
}
if o.peerManager != nil {
o.peerManager.Stop()
}
if o.originalPeerInterval == 0 && o.peerManager != nil {
peerMonitor := o.peerManager.GetPeerMonitor()
if peerMonitor != nil {
@@ -639,10 +631,6 @@ func (o *Olm) SetPowerMode(mode string) error {
}
}
if o.peerManager != nil {
o.peerManager.Start()
}
o.currentPowerMode = "low"
logger.Info("Switched to low power mode")
@@ -669,60 +657,19 @@ func (o *Olm) SetPowerMode(mode string) error {
}
}
if o.peerManager != nil {
o.peerManager.Start()
}
logger.Info("Reconnecting websocket for normal power mode")
if o.tunnelConfig.ID != "" && o.tunnelConfig.Secret != "" && o.tunnelConfig.Endpoint != "" {
logger.Info("Reconnecting websocket for normal power mode")
if o.olmClient != nil {
o.olmClient.Close()
}
o.stopPing = make(chan struct{})
var (
id = o.tunnelConfig.ID
secret = o.tunnelConfig.Secret
userToken = o.tunnelConfig.UserToken
)
olm, err := websocket.NewClient(
id,
secret,
userToken,
o.tunnelConfig.OrgID,
o.tunnelConfig.Endpoint,
o.tunnelConfig.PingIntervalDuration,
o.tunnelConfig.PingTimeoutDuration,
)
if err != nil {
logger.Error("Failed to create new websocket client: %v", err)
return fmt.Errorf("failed to create new websocket client: %w", err)
}
o.olmClient = olm
olm.OnConnect(func() error {
logger.Info("Websocket Reconnected")
o.apiServer.SetConnectionStatus(true)
go o.keepSendingPing(olm)
return nil
})
if err := olm.Connect(); err != nil {
if o.websocket != nil {
if err := o.websocket.Connect(); err != nil {
logger.Error("Failed to reconnect websocket: %v", err)
return fmt.Errorf("failed to reconnect websocket: %w", err)
}
} else {
logger.Warn("Cannot reconnect websocket: tunnel config not available")
}
o.currentPowerMode = "normal"
logger.Info("Switched to normal power mode")
}
return nil
}
@@ -749,6 +696,14 @@ func (o *Olm) AddDevice(fd uint32) error {
o.middleDev.AddDevice(tdev)
logger.Info("Added device from file descriptor %d", fd)
return nil
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}

View File

@@ -1,56 +0,0 @@
package olm
import (
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/network"
"github.com/fosrl/olm/websocket"
)
func sendPing(olm *websocket.Client) error {
logger.Debug("Sending ping message")
err := olm.SendMessage("olm/ping", map[string]any{
"timestamp": time.Now().Unix(),
"userToken": olm.GetConfig().UserToken,
})
if err != nil {
logger.Error("Failed to send ping message: %v", err)
return err
}
logger.Debug("Sent ping message")
return nil
}
func (o *Olm) keepSendingPing(olm *websocket.Client) {
// Send ping immediately on startup
if err := sendPing(olm); err != nil {
logger.Error("Failed to send initial ping: %v", err)
} else {
logger.Info("Sent initial ping message")
}
// Set up ticker for one minute intervals
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-o.stopPing:
logger.Info("Stopping ping messages")
return
case <-ticker.C:
if err := sendPing(olm); err != nil {
logger.Error("Failed to send periodic ping: %v", err)
}
}
}
}
func GetNetworkSettingsJSON() (string, error) {
return network.GetJSON()
}
func GetNetworkSettingsIncrementor() int {
return network.GetIncrementor()
}

View File

@@ -77,6 +77,7 @@ type Client struct {
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
isDisconnected bool // Flag to track if client is intentionally disconnected
reconnectMux sync.RWMutex
pingInterval time.Duration
pingTimeout time.Duration
@@ -173,6 +174,9 @@ func (c *Client) GetConfig() *Config {
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
if c.isDisconnected {
c.isDisconnected = false
}
go c.connectWithRetry()
return nil
}
@@ -205,9 +209,25 @@ func (c *Client) Close() error {
return nil
}
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
func (c *Client) Disconnect() error {
c.isDisconnected = true
c.setConnected(false)
if c.conn != nil {
c.writeMux.Lock()
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
c.writeMux.Unlock()
err := c.conn.Close()
c.conn = nil
return err
}
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return fmt.Errorf("not connected")
}
@@ -223,7 +243,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
stopChan := make(chan struct{})
updateChan := make(chan interface{})
var dataMux sync.Mutex
@@ -231,30 +251,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
go func() {
count := 0
maxAttempts := 10
err := c.SendMessage(messageType, currentData) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
send := func() {
if c.isDisconnected || c.conn == nil {
return
}
err := c.SendMessage(messageType, currentData)
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
}
count++
send() // Send immediately
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
if count >= maxAttempts {
if maxAttempts != -1 && count >= maxAttempts {
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
return
}
dataMux.Lock()
err = c.SendMessage(messageType, currentData)
send()
dataMux.Unlock()
if err != nil {
logger.Error("Failed to send message: %v", err)
}
count++
case newData := <-updateChan:
dataMux.Lock()
// Merge newData into currentData if both are maps
@@ -277,6 +299,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
case <-stopChan:
return
}
// Suspend sending if disconnected
for c.isDisconnected {
select {
case <-stopChan:
return
case <-time.After(500 * time.Millisecond):
}
}
}
}()
return func() {
@@ -587,7 +617,7 @@ func (c *Client) pingMonitor() {
case <-c.done:
return
case <-ticker.C:
if c.conn == nil {
if c.isDisconnected || c.conn == nil {
return
}
c.writeMux.Lock()