mirror of
https://github.com/fosrl/olm.git
synced 2026-02-08 05:56:41 +00:00
Merge branch 'power-state' into msg-delivery
This commit is contained in:
@@ -79,6 +79,7 @@ type Client struct {
|
||||
handlersMux sync.RWMutex
|
||||
reconnectInterval time.Duration
|
||||
isConnected bool
|
||||
isDisconnected bool // Flag to track if client is intentionally disconnected
|
||||
reconnectMux sync.RWMutex
|
||||
pingInterval time.Duration
|
||||
pingTimeout time.Duration
|
||||
@@ -91,6 +92,10 @@ type Client struct {
|
||||
configNeedsSave bool // Flag to track if config needs to be saved
|
||||
configVersion int // Latest config version received from server
|
||||
configVersionMux sync.RWMutex
|
||||
token string // Cached authentication token
|
||||
exitNodes []ExitNode // Cached exit nodes from token response
|
||||
tokenMux sync.RWMutex // Protects token and exitNodes
|
||||
forceNewToken bool // Flag to force fetching a new token on next connection
|
||||
}
|
||||
|
||||
type ClientOption func(*Client)
|
||||
@@ -177,6 +182,9 @@ func (c *Client) GetConfig() *Config {
|
||||
|
||||
// Connect establishes the WebSocket connection
|
||||
func (c *Client) Connect() error {
|
||||
if c.isDisconnected {
|
||||
c.isDisconnected = false
|
||||
}
|
||||
go c.connectWithRetry()
|
||||
return nil
|
||||
}
|
||||
@@ -209,9 +217,25 @@ func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Disconnect cleanly closes the websocket connection and suspends message intervals, but allows reconnecting later.
|
||||
func (c *Client) Disconnect() error {
|
||||
c.isDisconnected = true
|
||||
c.setConnected(false)
|
||||
|
||||
if c.conn != nil {
|
||||
c.writeMux.Lock()
|
||||
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||
c.writeMux.Unlock()
|
||||
err := c.conn.Close()
|
||||
c.conn = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessage sends a message through the WebSocket connection
|
||||
func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
if c.conn == nil {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
@@ -220,14 +244,14 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
Data: data,
|
||||
}
|
||||
|
||||
logger.Debug("Sending message: %s, data: %+v", messageType, data)
|
||||
logger.Debug("websocket: Sending message: %s, data: %+v", messageType, data)
|
||||
|
||||
c.writeMux.Lock()
|
||||
defer c.writeMux.Unlock()
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) {
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration, maxAttempts int) (stop func(), update func(newData interface{})) {
|
||||
stopChan := make(chan struct{})
|
||||
updateChan := make(chan interface{})
|
||||
var dataMux sync.Mutex
|
||||
@@ -235,30 +259,32 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
||||
|
||||
go func() {
|
||||
count := 0
|
||||
maxAttempts := 10
|
||||
|
||||
err := c.SendMessage(messageType, currentData) // Send immediately
|
||||
if err != nil {
|
||||
logger.Error("Failed to send initial message: %v", err)
|
||||
send := func() {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return
|
||||
}
|
||||
err := c.SendMessage(messageType, currentData)
|
||||
if err != nil {
|
||||
logger.Error("websocket: Failed to send message: %v", err)
|
||||
}
|
||||
count++
|
||||
}
|
||||
count++
|
||||
|
||||
send() // Send immediately
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if count >= maxAttempts {
|
||||
logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
if maxAttempts != -1 && count >= maxAttempts {
|
||||
logger.Info("websocket: SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType)
|
||||
return
|
||||
}
|
||||
dataMux.Lock()
|
||||
err = c.SendMessage(messageType, currentData)
|
||||
send()
|
||||
dataMux.Unlock()
|
||||
if err != nil {
|
||||
logger.Error("Failed to send message: %v", err)
|
||||
}
|
||||
count++
|
||||
case newData := <-updateChan:
|
||||
dataMux.Lock()
|
||||
// Merge newData into currentData if both are maps
|
||||
@@ -281,6 +307,14 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
// Suspend sending if disconnected
|
||||
for c.isDisconnected {
|
||||
select {
|
||||
case <-stopChan:
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
@@ -327,7 +361,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
tlsConfig = &tls.Config{}
|
||||
}
|
||||
tlsConfig.InsecureSkipVerify = true
|
||||
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
logger.Debug("websocket: TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
tokenData := map[string]interface{}{
|
||||
@@ -356,7 +390,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
|
||||
|
||||
// print out the request for debugging
|
||||
logger.Debug("Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
logger.Debug("websocket: Requesting token from %s with body: %s", req.URL.String(), string(jsonData))
|
||||
|
||||
// Make the request
|
||||
client := &http.Client{}
|
||||
@@ -373,7 +407,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
logger.Error("websocket: Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body))
|
||||
|
||||
// Return AuthError for 401/403 status codes
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
@@ -389,7 +423,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
|
||||
var tokenResp TokenResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
|
||||
logger.Error("Failed to decode token response.")
|
||||
logger.Error("websocket: Failed to decode token response.")
|
||||
return "", nil, fmt.Errorf("failed to decode token response: %w", err)
|
||||
}
|
||||
|
||||
@@ -401,7 +435,7 @@ func (c *Client) getToken() (string, []ExitNode, error) {
|
||||
return "", nil, fmt.Errorf("received empty token from server")
|
||||
}
|
||||
|
||||
logger.Debug("Received token: %s", tokenResp.Data.Token)
|
||||
logger.Debug("websocket: Received token: %s", tokenResp.Data.Token)
|
||||
|
||||
return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil
|
||||
}
|
||||
@@ -427,7 +461,7 @@ func (c *Client) connectWithRetry() {
|
||||
continue
|
||||
}
|
||||
// For other errors (5xx, network issues), continue retrying
|
||||
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
logger.Error("websocket: Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
|
||||
time.Sleep(c.reconnectInterval)
|
||||
continue
|
||||
}
|
||||
@@ -437,15 +471,25 @@ func (c *Client) connectWithRetry() {
|
||||
}
|
||||
|
||||
func (c *Client) establishConnection() error {
|
||||
// Get token for authentication
|
||||
token, exitNodes, err := c.getToken()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
// Get token for authentication - reuse cached token unless forced to get new one
|
||||
c.tokenMux.Lock()
|
||||
needNewToken := c.token == "" || c.forceNewToken
|
||||
if needNewToken {
|
||||
token, exitNodes, err := c.getToken()
|
||||
if err != nil {
|
||||
c.tokenMux.Unlock()
|
||||
return fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
c.token = token
|
||||
c.exitNodes = exitNodes
|
||||
c.forceNewToken = false
|
||||
|
||||
if c.onTokenUpdate != nil {
|
||||
c.onTokenUpdate(token, exitNodes)
|
||||
if c.onTokenUpdate != nil {
|
||||
c.onTokenUpdate(token, exitNodes)
|
||||
}
|
||||
}
|
||||
token := c.token
|
||||
c.tokenMux.Unlock()
|
||||
|
||||
// Parse the base URL to determine protocol and hostname
|
||||
baseURL, err := url.Parse(c.baseURL)
|
||||
@@ -480,7 +524,7 @@ func (c *Client) establishConnection() error {
|
||||
|
||||
// Use new TLS configuration method
|
||||
if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("Setting up TLS configuration for WebSocket connection")
|
||||
logger.Info("websocket: Setting up TLS configuration for WebSocket connection")
|
||||
tlsConfig, err := c.setupTLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup TLS configuration: %w", err)
|
||||
@@ -494,11 +538,23 @@ func (c *Client) establishConnection() error {
|
||||
dialer.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
dialer.TLSClientConfig.InsecureSkipVerify = true
|
||||
logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
logger.Debug("websocket: WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
|
||||
}
|
||||
|
||||
conn, _, err := dialer.Dial(u.String(), nil)
|
||||
conn, resp, err := dialer.Dial(u.String(), nil)
|
||||
if err != nil {
|
||||
// Check if this is an unauthorized error (401)
|
||||
if resp != nil && resp.StatusCode == http.StatusUnauthorized {
|
||||
logger.Error("websocket: WebSocket connection rejected with 401 Unauthorized")
|
||||
// Force getting a new token on next reconnect attempt
|
||||
c.tokenMux.Lock()
|
||||
c.forceNewToken = true
|
||||
c.tokenMux.Unlock()
|
||||
return &AuthError{
|
||||
StatusCode: http.StatusUnauthorized,
|
||||
Message: "WebSocket connection unauthorized",
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("failed to connect to WebSocket: %w", err)
|
||||
}
|
||||
|
||||
@@ -512,7 +568,7 @@ func (c *Client) establishConnection() error {
|
||||
|
||||
if c.onConnect != nil {
|
||||
if err := c.onConnect(); err != nil {
|
||||
logger.Error("OnConnect callback failed: %v", err)
|
||||
logger.Error("websocket: OnConnect callback failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -525,9 +581,9 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Handle new separate certificate configuration
|
||||
if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" {
|
||||
logger.Info("Loading separate certificate files for mTLS")
|
||||
logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||
logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||
logger.Info("websocket: Loading separate certificate files for mTLS")
|
||||
logger.Debug("websocket: Client cert: %s", c.tlsConfig.ClientCertFile)
|
||||
logger.Debug("websocket: Client key: %s", c.tlsConfig.ClientKeyFile)
|
||||
|
||||
// Load client certificate and key
|
||||
cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile)
|
||||
@@ -538,7 +594,7 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Load CA certificates for remote validation if specified
|
||||
if len(c.tlsConfig.CAFiles) > 0 {
|
||||
logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||
logger.Debug("websocket: Loading CA certificates: %v", c.tlsConfig.CAFiles)
|
||||
caCertPool := x509.NewCertPool()
|
||||
for _, caFile := range c.tlsConfig.CAFiles {
|
||||
caCert, err := os.ReadFile(caFile)
|
||||
@@ -564,13 +620,13 @@ func (c *Client) setupTLS() (*tls.Config, error) {
|
||||
|
||||
// Fallback to existing PKCS12 implementation for backward compatibility
|
||||
if c.tlsConfig.PKCS12File != "" {
|
||||
logger.Info("Loading PKCS12 certificate for mTLS (deprecated)")
|
||||
logger.Info("websocket: Loading PKCS12 certificate for mTLS (deprecated)")
|
||||
return c.setupPKCS12TLS()
|
||||
}
|
||||
|
||||
// Legacy fallback using config.TlsClientCert
|
||||
if c.config.TlsClientCert != "" {
|
||||
logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||
logger.Info("websocket: Loading legacy PKCS12 certificate for mTLS (deprecated)")
|
||||
return loadClientCertificate(c.config.TlsClientCert)
|
||||
}
|
||||
|
||||
@@ -592,7 +648,7 @@ func (c *Client) pingMonitor() {
|
||||
case <-c.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
if c.conn == nil {
|
||||
if c.isDisconnected || c.conn == nil {
|
||||
return
|
||||
}
|
||||
// Send application-level ping with config version
|
||||
@@ -616,7 +672,7 @@ func (c *Client) pingMonitor() {
|
||||
// Expected during shutdown
|
||||
return
|
||||
default:
|
||||
logger.Error("Ping failed: %v", err)
|
||||
logger.Error("websocket: Ping failed: %v", err)
|
||||
c.reconnect()
|
||||
return
|
||||
}
|
||||
@@ -665,18 +721,24 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
||||
var msg WSMessage
|
||||
err := c.conn.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
// Check if we're shutting down before logging error
|
||||
// Check if we're shutting down or explicitly disconnected before logging error
|
||||
select {
|
||||
case <-c.done:
|
||||
// Expected during shutdown, don't log as error
|
||||
logger.Debug("WebSocket connection closed during shutdown")
|
||||
logger.Debug("websocket: connection closed during shutdown")
|
||||
return
|
||||
default:
|
||||
// Check if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: connection closed: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
// Unexpected error during normal operation
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) {
|
||||
logger.Error("WebSocket read error: %v", err)
|
||||
logger.Error("websocket: read error: %v", err)
|
||||
} else {
|
||||
logger.Debug("WebSocket connection closed: %v", err)
|
||||
logger.Debug("websocket: connection closed: %v", err)
|
||||
}
|
||||
return // triggers reconnect via defer
|
||||
}
|
||||
@@ -703,6 +765,12 @@ func (c *Client) reconnect() {
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
// Don't reconnect if explicitly disconnected
|
||||
if c.isDisconnected {
|
||||
logger.Debug("websocket: websocket: Not reconnecting: client was explicitly disconnected")
|
||||
return
|
||||
}
|
||||
|
||||
// Only reconnect if we're not shutting down
|
||||
select {
|
||||
case <-c.done:
|
||||
@@ -720,7 +788,7 @@ func (c *Client) setConnected(status bool) {
|
||||
|
||||
// LoadClientCertificate Helper method to load client certificates (PKCS12 format)
|
||||
func loadClientCertificate(p12Path string) (*tls.Config, error) {
|
||||
logger.Info("Loading tls-client-cert %s", p12Path)
|
||||
logger.Info("websocket: Loading tls-client-cert %s", p12Path)
|
||||
// Read the PKCS12 file
|
||||
p12Data, err := os.ReadFile(p12Path)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user