mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
@@ -154,7 +154,7 @@ func (o *Olm) handleConnect(msg websocket.WSMessage) {
|
|||||||
MiddleDev: o.middleDev,
|
MiddleDev: o.middleDev,
|
||||||
LocalIP: interfaceIP,
|
LocalIP: interfaceIP,
|
||||||
SharedBind: o.sharedBind,
|
SharedBind: o.sharedBind,
|
||||||
WSClient: o.olmClient,
|
WSClient: o.websocket,
|
||||||
APIServer: o.apiServer,
|
APIServer: o.apiServer,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -189,9 +189,9 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
o.holePunchManager.ResetInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
// Send handshake acknowledgment back to server with retry
|
// Send handshake acknowledgment back to server with retry
|
||||||
o.stopPeerSend, _ = o.olmClient.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
"siteId": handshakeData.SiteId,
|
"siteId": handshakeData.SiteId,
|
||||||
}, 1*time.Second)
|
}, 1*time.Second, 10)
|
||||||
|
|
||||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
}
|
}
|
||||||
|
|||||||
109
olm/olm.go
109
olm/olm.go
@@ -41,7 +41,7 @@ type Olm struct {
|
|||||||
|
|
||||||
dnsProxy *dns.DNSProxy
|
dnsProxy *dns.DNSProxy
|
||||||
apiServer *api.API
|
apiServer *api.API
|
||||||
olmClient *websocket.Client
|
websocket *websocket.Client
|
||||||
holePunchManager *holepunch.Manager
|
holePunchManager *holepunch.Manager
|
||||||
peerManager *peers.PeerManager
|
peerManager *peers.PeerManager
|
||||||
// Power mode management
|
// Power mode management
|
||||||
@@ -57,10 +57,11 @@ type Olm struct {
|
|||||||
tunnelConfig TunnelConfig
|
tunnelConfig TunnelConfig
|
||||||
|
|
||||||
stopRegister func()
|
stopRegister func()
|
||||||
stopPeerSend func()
|
|
||||||
updateRegister func(newData any)
|
updateRegister func(newData any)
|
||||||
|
|
||||||
stopPing chan struct{}
|
stopServerPing func()
|
||||||
|
|
||||||
|
stopPeerSend func()
|
||||||
}
|
}
|
||||||
|
|
||||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||||
@@ -270,9 +271,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
tunnelCtx, cancel := context.WithCancel(o.olmCtx)
|
||||||
o.tunnelCancel = cancel
|
o.tunnelCancel = cancel
|
||||||
|
|
||||||
// Recreate channels for this tunnel session
|
|
||||||
o.stopPing = make(chan struct{})
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
id = config.ID
|
id = config.ID
|
||||||
secret = config.Secret
|
secret = config.Secret
|
||||||
@@ -328,6 +326,14 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
o.apiServer.SetConnectionStatus(true)
|
o.apiServer.SetConnectionStatus(true)
|
||||||
|
|
||||||
|
// restart the ping if we need to
|
||||||
|
if o.stopServerPing == nil {
|
||||||
|
o.stopServerPing, _ = olmClient.SendMessageInterval("olm/ping", map[string]any{
|
||||||
|
"timestamp": time.Now().Unix(),
|
||||||
|
"userToken": olmClient.GetConfig().UserToken,
|
||||||
|
}, 30*time.Second, -1) // -1 means dont time out with the max attempts
|
||||||
|
}
|
||||||
|
|
||||||
if o.connected {
|
if o.connected {
|
||||||
logger.Debug("Already connected, skipping registration")
|
logger.Debug("Already connected, skipping registration")
|
||||||
return nil
|
return nil
|
||||||
@@ -347,7 +353,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
"olmAgent": o.olmConfig.Agent,
|
"olmAgent": o.olmConfig.Agent,
|
||||||
"orgId": config.OrgID,
|
"orgId": config.OrgID,
|
||||||
"userToken": userToken,
|
"userToken": userToken,
|
||||||
}, 1*time.Second)
|
}, 1*time.Second, 10)
|
||||||
|
|
||||||
// Invoke onRegistered callback if configured
|
// Invoke onRegistered callback if configured
|
||||||
if o.olmConfig.OnRegistered != nil {
|
if o.olmConfig.OnRegistered != nil {
|
||||||
@@ -355,8 +361,6 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
go o.keepSendingPing(olmClient)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -416,7 +420,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
}
|
}
|
||||||
defer func() { _ = olmClient.Close() }()
|
defer func() { _ = olmClient.Close() }()
|
||||||
|
|
||||||
o.olmClient = olmClient
|
o.websocket = olmClient
|
||||||
|
|
||||||
// Wait for context cancellation
|
// Wait for context cancellation
|
||||||
<-tunnelCtx.Done()
|
<-tunnelCtx.Done()
|
||||||
@@ -435,9 +439,9 @@ func (o *Olm) Close() {
|
|||||||
o.holePunchManager = nil
|
o.holePunchManager = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopPing != nil {
|
if o.stopServerPing != nil {
|
||||||
close(o.stopPing)
|
o.stopServerPing()
|
||||||
o.stopPing = nil
|
o.stopServerPing = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopRegister != nil {
|
if o.stopRegister != nil {
|
||||||
@@ -515,9 +519,9 @@ func (o *Olm) StopTunnel() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close the websocket connection
|
// Close the websocket connection
|
||||||
if o.olmClient != nil {
|
if o.websocket != nil {
|
||||||
_ = o.olmClient.Close()
|
_ = o.websocket.Close()
|
||||||
o.olmClient = nil
|
o.websocket = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
o.Close()
|
o.Close()
|
||||||
@@ -602,25 +606,13 @@ func (o *Olm) SetPowerMode(mode string) error {
|
|||||||
if mode == "low" {
|
if mode == "low" {
|
||||||
// Low Power Mode: Close websocket and reduce monitoring frequency
|
// Low Power Mode: Close websocket and reduce monitoring frequency
|
||||||
|
|
||||||
if o.olmClient != nil {
|
if o.websocket != nil {
|
||||||
logger.Info("Closing websocket connection for low power mode")
|
logger.Info("Closing websocket connection for low power mode")
|
||||||
if err := o.olmClient.Close(); err != nil {
|
if err := o.websocket.Close(); err != nil {
|
||||||
logger.Error("Error closing websocket: %v", err)
|
logger.Error("Error closing websocket: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopPing != nil {
|
|
||||||
select {
|
|
||||||
case <-o.stopPing:
|
|
||||||
default:
|
|
||||||
close(o.stopPing)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.peerManager != nil {
|
|
||||||
o.peerManager.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.originalPeerInterval == 0 && o.peerManager != nil {
|
if o.originalPeerInterval == 0 && o.peerManager != nil {
|
||||||
peerMonitor := o.peerManager.GetPeerMonitor()
|
peerMonitor := o.peerManager.GetPeerMonitor()
|
||||||
if peerMonitor != nil {
|
if peerMonitor != nil {
|
||||||
@@ -639,10 +631,6 @@ func (o *Olm) SetPowerMode(mode string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.peerManager != nil {
|
|
||||||
o.peerManager.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
o.currentPowerMode = "low"
|
o.currentPowerMode = "low"
|
||||||
logger.Info("Switched to low power mode")
|
logger.Info("Switched to low power mode")
|
||||||
|
|
||||||
@@ -669,54 +657,13 @@ func (o *Olm) SetPowerMode(mode string) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.peerManager != nil {
|
|
||||||
o.peerManager.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
if o.tunnelConfig.ID != "" && o.tunnelConfig.Secret != "" && o.tunnelConfig.Endpoint != "" {
|
|
||||||
logger.Info("Reconnecting websocket for normal power mode")
|
logger.Info("Reconnecting websocket for normal power mode")
|
||||||
|
|
||||||
if o.olmClient != nil {
|
if o.websocket != nil {
|
||||||
o.olmClient.Close()
|
if err := o.websocket.Connect(); err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
o.stopPing = make(chan struct{})
|
|
||||||
|
|
||||||
var (
|
|
||||||
id = o.tunnelConfig.ID
|
|
||||||
secret = o.tunnelConfig.Secret
|
|
||||||
userToken = o.tunnelConfig.UserToken
|
|
||||||
)
|
|
||||||
|
|
||||||
olm, err := websocket.NewClient(
|
|
||||||
id,
|
|
||||||
secret,
|
|
||||||
userToken,
|
|
||||||
o.tunnelConfig.OrgID,
|
|
||||||
o.tunnelConfig.Endpoint,
|
|
||||||
o.tunnelConfig.PingIntervalDuration,
|
|
||||||
o.tunnelConfig.PingTimeoutDuration,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to create new websocket client: %v", err)
|
|
||||||
return fmt.Errorf("failed to create new websocket client: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
o.olmClient = olm
|
|
||||||
|
|
||||||
olm.OnConnect(func() error {
|
|
||||||
logger.Info("Websocket Reconnected")
|
|
||||||
o.apiServer.SetConnectionStatus(true)
|
|
||||||
go o.keepSendingPing(olm)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := olm.Connect(); err != nil {
|
|
||||||
logger.Error("Failed to reconnect websocket: %v", err)
|
logger.Error("Failed to reconnect websocket: %v", err)
|
||||||
return fmt.Errorf("failed to reconnect websocket: %w", err)
|
return fmt.Errorf("failed to reconnect websocket: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
logger.Warn("Cannot reconnect websocket: tunnel config not available")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
o.currentPowerMode = "normal"
|
o.currentPowerMode = "normal"
|
||||||
@@ -752,3 +699,11 @@ func (o *Olm) AddDevice(fd uint32) error {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetNetworkSettingsJSON() (string, error) {
|
||||||
|
return network.GetJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetNetworkSettingsIncrementor() int {
|
||||||
|
return network.GetIncrementor()
|
||||||
|
}
|
||||||
|
|||||||
56
olm/ping.go
56
olm/ping.go
@@ -1,56 +0,0 @@
|
|||||||
package olm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
|
||||||
"github.com/fosrl/newt/network"
|
|
||||||
"github.com/fosrl/olm/websocket"
|
|
||||||
)
|
|
||||||
|
|
||||||
func sendPing(olm *websocket.Client) error {
|
|
||||||
logger.Debug("Sending ping message")
|
|
||||||
err := olm.SendMessage("olm/ping", map[string]any{
|
|
||||||
"timestamp": time.Now().Unix(),
|
|
||||||
"userToken": olm.GetConfig().UserToken,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send ping message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
logger.Debug("Sent ping message")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *Olm) keepSendingPing(olm *websocket.Client) {
|
|
||||||
// Send ping immediately on startup
|
|
||||||
if err := sendPing(olm); err != nil {
|
|
||||||
logger.Error("Failed to send initial ping: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Sent initial ping message")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up ticker for one minute intervals
|
|
||||||
ticker := time.NewTicker(30 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-o.stopPing:
|
|
||||||
logger.Info("Stopping ping messages")
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if err := sendPing(olm); err != nil {
|
|
||||||
logger.Error("Failed to send periodic ping: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNetworkSettingsJSON() (string, error) {
|
|
||||||
return network.GetJSON()
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetNetworkSettingsIncrementor() int {
|
|
||||||
return network.GetIncrementor()
|
|
||||||
}
|
|
||||||
@@ -77,6 +77,7 @@ type Client struct {
|
|||||||
handlersMux sync.RWMutex
|
handlersMux sync.RWMutex
|
||||||
reconnectInterval time.Duration
|
reconnectInterval time.Duration
|
||||||
isConnected bool
|
isConnected bool
|
||||||
|
isDisconnected bool // Flag to track if client is intentionally disconnected
|
||||||
reconnectMux sync.RWMutex
|
reconnectMux sync.RWMutex
|
||||||
pingInterval time.Duration
|
pingInterval time.Duration
|
||||||
pingTimeout time.Duration
|
pingTimeout time.Duration
|
||||||
@@ -173,6 +174,9 @@ func (c *Client) GetConfig() *Config {
|
|||||||
|
|
||||||
// Connect establishes the WebSocket connection
|
// Connect establishes the WebSocket connection
|
||||||
func (c *Client) Connect() error {
|
func (c *Client) Connect() error {
|
||||||
|
if c.isDisconnected {
|
||||||
|
c.isDisconnected = false
|
||||||
|
}
|
||||||
go c.connectWithRetry()
|
go c.connectWithRetry()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -205,9 +209,25 @@ func (c *Client) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
|
||||||
|
func (c *Client) Disconnect() error {
|
||||||
|
c.isDisconnected = true
|
||||||
|
c.setConnected(false)
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
c.writeMux.Lock()
|
||||||
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
|
c.writeMux.Unlock()
|
||||||
|
err := c.conn.Close()
|
||||||
|
c.conn = nil
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// SendMessage sends a message through the WebSocket connection
|
// SendMessage sends a message through the WebSocket connection
|
||||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||||
if c.conn == nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
return fmt.Errorf("not connected")
|
return fmt.Errorf("not connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,7 +243,7 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
|||||||
return c.conn.WriteJSON(msg)
|
return c.conn.WriteJSON(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
|
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
|
||||||
stopChan := make(chan struct{})
|
stopChan := make(chan struct{})
|
||||||
updateChan := make(chan interface{})
|
updateChan := make(chan interface{})
|
||||||
var dataMux sync.Mutex
|
var dataMux sync.Mutex
|
||||||
@@ -231,30 +251,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
count := 0
|
count := 0
|
||||||
maxAttempts := 10
|
|
||||||
|
|
||||||
err := c.SendMessage(messageType, currentData) // Send immediately
|
send := func() {
|
||||||
|
if c.isDisconnected || c.conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err := c.SendMessage(messageType, currentData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send initial message: %v", err)
|
logger.Error("Failed to send message: %v", err)
|
||||||
}
|
}
|
||||||
count++
|
count++
|
||||||
|
}
|
||||||
|
|
||||||
|
send() // Send immediately
|
||||||
|
|
||||||
ticker := time.NewTicker(interval)
|
ticker := time.NewTicker(interval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if count >= maxAttempts {
|
if maxAttempts != -1 && count >= maxAttempts {
|
||||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
dataMux.Lock()
|
dataMux.Lock()
|
||||||
err = c.SendMessage(messageType, currentData)
|
send()
|
||||||
dataMux.Unlock()
|
dataMux.Unlock()
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send message: %v", err)
|
|
||||||
}
|
|
||||||
count++
|
|
||||||
case newData := <-updateChan:
|
case newData := <-updateChan:
|
||||||
dataMux.Lock()
|
dataMux.Lock()
|
||||||
// Merge newData into currentData if both are maps
|
// Merge newData into currentData if both are maps
|
||||||
@@ -277,6 +299,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
|||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Suspend sending if disconnected
|
||||||
|
for c.isDisconnected {
|
||||||
|
select {
|
||||||
|
case <-stopChan:
|
||||||
|
return
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
return func() {
|
return func() {
|
||||||
@@ -587,7 +617,7 @@ func (c *Client) pingMonitor() {
|
|||||||
case <-c.done:
|
case <-c.done:
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if c.conn == nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.writeMux.Lock()
|
c.writeMux.Lock()
|
||||||
|
|||||||
Reference in New Issue
Block a user