Fix small issues; add back ws

This commit is contained in:
Owen
2025-02-21 16:11:57 -05:00
parent eb77be09e2
commit 8b09545cf6
6 changed files with 448 additions and 10 deletions

351
websocket/client.go Normal file
View File

@@ -0,0 +1,351 @@
package websocket
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"time"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
)
type Client struct {
conn *websocket.Conn
config *Config
baseURL string
handlers map[string]MessageHandler
done chan struct{}
handlersMux sync.RWMutex
reconnectInterval time.Duration
isConnected bool
reconnectMux sync.RWMutex
onConnect func() error
}
type ClientOption func(*Client)
type MessageHandler func(message WSMessage)
// WithBaseURL sets the base URL for the client
func WithBaseURL(url string) ClientOption {
return func(c *Client) {
c.baseURL = url
}
}
func (c *Client) OnConnect(callback func() error) {
c.onConnect = callback
}
// NewClient creates a new Olm client
func NewClient(olmID, secret string, endpoint string, opts ...ClientOption) (*Client, error) {
config := &Config{
OlmID: olmID,
Secret: secret,
Endpoint: endpoint,
}
client := &Client{
config: config,
baseURL: endpoint, // default value
handlers: make(map[string]MessageHandler),
done: make(chan struct{}),
reconnectInterval: 10 * time.Second,
isConnected: false,
}
// Apply options before loading config
for _, opt := range opts {
opt(client)
}
// Load existing config if available
if err := client.loadConfig(); err != nil {
return nil, fmt.Errorf("failed to load config: %w", err)
}
return client, nil
}
// Connect establishes the WebSocket connection
func (c *Client) Connect() error {
go c.connectWithRetry()
return nil
}
// Close closes the WebSocket connection
func (c *Client) Close() error {
close(c.done)
if c.conn != nil {
return c.conn.Close()
}
// stop the ping monitor
c.setConnected(false)
return nil
}
// SendMessage sends a message through the WebSocket connection
func (c *Client) SendMessage(messageType string, data interface{}) error {
if c.conn == nil {
return fmt.Errorf("not connected")
}
msg := WSMessage{
Type: messageType,
Data: data,
}
return c.conn.WriteJSON(msg)
}
// RegisterHandler registers a handler for a specific message type
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlersMux.Lock()
defer c.handlersMux.Unlock()
c.handlers[messageType] = handler
}
// readPump pumps messages from the WebSocket connection
func (c *Client) readPump() {
defer c.conn.Close()
for {
select {
case <-c.done:
return
default:
var msg WSMessage
err := c.conn.ReadJSON(&msg)
if err != nil {
return
}
c.handlersMux.RLock()
if handler, ok := c.handlers[msg.Type]; ok {
handler(msg)
}
c.handlersMux.RUnlock()
}
}
}
func (c *Client) getToken() (string, error) {
// Parse the base URL to ensure we have the correct hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return "", fmt.Errorf("failed to parse base URL: %w", err)
}
// Ensure we have the base URL without trailing slashes
baseEndpoint := strings.TrimRight(baseURL.String(), "/")
// If we already have a token, try to use it
if c.config.Token != "" {
tokenCheckData := map[string]interface{}{
"olmId": c.config.OlmID,
"secret": c.config.Secret,
"token": c.config.Token,
}
jsonData, err := json.Marshal(tokenCheckData)
if err != nil {
return "", fmt.Errorf("failed to marshal token check data: %w", err)
}
// Create a new request
req, err := http.NewRequest(
"POST",
baseEndpoint+"/api/v1/auth/olm/get-token",
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Make the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to check token validity: %w", err)
}
defer resp.Body.Close()
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token check response: %w", err)
}
// If token is still valid, return it
if tokenResp.Success && tokenResp.Message == "Token session already valid" {
return c.config.Token, nil
}
}
// Get a new token
tokenData := map[string]interface{}{
"olmId": c.config.OlmID,
"secret": c.config.Secret,
}
jsonData, err := json.Marshal(tokenData)
if err != nil {
return "", fmt.Errorf("failed to marshal token request data: %w", err)
}
// Create a new request
req, err := http.NewRequest(
"POST",
baseEndpoint+"/api/v1/auth/olm/get-token",
bytes.NewBuffer(jsonData),
)
if err != nil {
return "", fmt.Errorf("failed to create request: %w", err)
}
// Set headers
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-CSRF-Token", "x-csrf-protection")
// Make the request
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to request new token: %w", err)
}
defer resp.Body.Close()
var tokenResp TokenResponse
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
return "", fmt.Errorf("failed to decode token response: %w", err)
}
if !tokenResp.Success {
return "", fmt.Errorf("failed to get token: %s", tokenResp.Message)
}
if tokenResp.Data.Token == "" {
return "", fmt.Errorf("received empty token from server")
}
return tokenResp.Data.Token, nil
}
func (c *Client) connectWithRetry() {
for {
select {
case <-c.done:
return
default:
err := c.establishConnection()
if err != nil {
logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval)
time.Sleep(c.reconnectInterval)
continue
}
return
}
}
}
func (c *Client) establishConnection() error {
// Get token for authentication
token, err := c.getToken()
if err != nil {
return fmt.Errorf("failed to get token: %w", err)
}
// Parse the base URL to determine protocol and hostname
baseURL, err := url.Parse(c.baseURL)
if err != nil {
return fmt.Errorf("failed to parse base URL: %w", err)
}
// Determine WebSocket protocol based on HTTP protocol
wsProtocol := "wss"
if baseURL.Scheme == "http" {
wsProtocol = "ws"
}
// Create WebSocket URL
wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host)
u, err := url.Parse(wsURL)
if err != nil {
return fmt.Errorf("failed to parse WebSocket URL: %w", err)
}
// Add token to query parameters
q := u.Query()
q.Set("token", token)
u.RawQuery = q.Encode()
// Connect to WebSocket
conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil)
if err != nil {
return fmt.Errorf("failed to connect to WebSocket: %w", err)
}
c.conn = conn
c.setConnected(true)
// Start the ping monitor
go c.pingMonitor()
// Start the read pump
go c.readPump()
if c.onConnect != nil {
err := c.saveConfig()
if err != nil {
logger.Error("Failed to save config: %v", err)
}
if err := c.onConnect(); err != nil {
logger.Error("OnConnect callback failed: %v", err)
}
}
return nil
}
func (c *Client) pingMonitor() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-c.done:
return
case <-ticker.C:
if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil {
logger.Error("Ping failed: %v", err)
c.reconnect()
return
}
}
}
}
func (c *Client) reconnect() {
c.setConnected(false)
if c.conn != nil {
c.conn.Close()
}
go c.connectWithRetry()
}
func (c *Client) setConnected(status bool) {
c.reconnectMux.Lock()
defer c.reconnectMux.Unlock()
c.isConnected = status
}

72
websocket/config.go Normal file
View File

@@ -0,0 +1,72 @@
package websocket
import (
"encoding/json"
"log"
"os"
"path/filepath"
"runtime"
)
func getConfigPath() string {
var configDir string
switch runtime.GOOS {
case "darwin":
configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client")
case "windows":
configDir = filepath.Join(os.Getenv("APPDATA"), "olm-client")
default: // linux and others
configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client")
}
if err := os.MkdirAll(configDir, 0755); err != nil {
log.Printf("Failed to create config directory: %v", err)
}
return filepath.Join(configDir, "config.json")
}
func (c *Client) loadConfig() error {
if c.config.OlmID != "" && c.config.Secret != "" && c.config.Endpoint != "" {
return nil
}
configPath := getConfigPath()
data, err := os.ReadFile(configPath)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
var config Config
if err := json.Unmarshal(data, &config); err != nil {
return err
}
if c.config.OlmID == "" {
c.config.OlmID = config.OlmID
}
if c.config.Token == "" {
c.config.Token = config.Token
}
if c.config.Secret == "" {
c.config.Secret = config.Secret
}
if c.config.Endpoint == "" {
c.config.Endpoint = config.Endpoint
c.baseURL = config.Endpoint
}
return nil
}
func (c *Client) saveConfig() error {
configPath := getConfigPath()
data, err := json.MarshalIndent(c.config, "", " ")
if err != nil {
return err
}
return os.WriteFile(configPath, data, 0644)
}

21
websocket/types.go Normal file
View File

@@ -0,0 +1,21 @@
package websocket
type Config struct {
OlmID string `json:"olmId"`
Secret string `json:"secret"`
Token string `json:"token"`
Endpoint string `json:"endpoint"`
}
type TokenResponse struct {
Data struct {
Token string `json:"token"`
} `json:"data"`
Success bool `json:"success"`
Message string `json:"message"`
}
type WSMessage struct {
Type string `json:"type"`
Data interface{} `json:"data"`
}