diff --git a/main.go b/main.go index 3ef705c..339ea2f 100644 --- a/main.go +++ b/main.go @@ -15,9 +15,9 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/updates" - "github.com/fosrl/newt/websocket" "github.com/fosrl/olm/httpserver" "github.com/fosrl/olm/peermonitor" + "github.com/fosrl/olm/websocket" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..d1ab3da --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,637 @@ +package websocket + +import ( + "bytes" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" + + "software.sslmate.com/src/go-pkcs12" + + "github.com/fosrl/newt/logger" + "github.com/gorilla/websocket" +) + +type TokenResponse struct { + Data struct { + Token string `json:"token"` + } `json:"data"` + Success bool `json:"success"` + Message string `json:"message"` +} + +type WSMessage struct { + Type string `json:"type"` + Data interface{} `json:"data"` +} + +// this is not json anymore +type Config struct { + ID string + Secret string + Endpoint string + TlsClientCert string // legacy PKCS12 file path +} + +type Client struct { + config *Config + conn *websocket.Conn + baseURL string + handlers map[string]MessageHandler + done chan struct{} + handlersMux sync.RWMutex + reconnectInterval time.Duration + isConnected bool + reconnectMux sync.RWMutex + pingInterval time.Duration + pingTimeout time.Duration + onConnect func() error + onTokenUpdate func(token string) + writeMux sync.Mutex + clientType string // Type of client (e.g., "newt", "olm") + tlsConfig TLSConfig + configNeedsSave bool // Flag to track if config needs to be saved +} + +type ClientOption func(*Client) + +type MessageHandler func(message WSMessage) + +// TLSConfig holds TLS configuration options +type TLSConfig struct { + // New separate certificate support + ClientCertFile string + ClientKeyFile string + CAFiles []string + + // Existing PKCS12 support (deprecated) + PKCS12File string +} + +// WithBaseURL sets the base URL for the client +func WithBaseURL(url string) ClientOption { + return func(c *Client) { + c.baseURL = url + } +} + +// WithTLSConfig sets the TLS configuration for the client +func WithTLSConfig(config TLSConfig) ClientOption { + return func(c *Client) { + c.tlsConfig = config + // For backward compatibility, also set the legacy field + if config.PKCS12File != "" { + c.config.TlsClientCert = config.PKCS12File + } + } +} + +func (c *Client) OnConnect(callback func() error) { + c.onConnect = callback +} + +func (c *Client) OnTokenUpdate(callback func(token string)) { + c.onTokenUpdate = callback +} + +// NewClient creates a new websocket client +func NewClient(clientType string, ID, secret string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { + config := &Config{ + ID: ID, + Secret: secret, + Endpoint: endpoint, + } + + client := &Client{ + config: config, + baseURL: endpoint, // default value + handlers: make(map[string]MessageHandler), + done: make(chan struct{}), + reconnectInterval: 3 * time.Second, + isConnected: false, + pingInterval: pingInterval, + pingTimeout: pingTimeout, + clientType: clientType, + } + + // Apply options before loading config + for _, opt := range opts { + if opt == nil { + continue + } + opt(client) + } + + return client, nil +} + +func (c *Client) GetConfig() *Config { + return c.config +} + +// Connect establishes the WebSocket connection +func (c *Client) Connect() error { + go c.connectWithRetry() + return nil +} + +// Close closes the WebSocket connection gracefully +func (c *Client) Close() error { + // Signal shutdown to all goroutines first + select { + case <-c.done: + // Already closed + return nil + default: + close(c.done) + } + + // Set connection status to false + c.setConnected(false) + + // Close the WebSocket connection gracefully + if c.conn != nil { + // Send close message + c.writeMux.Lock() + c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + c.writeMux.Unlock() + + // Close the connection + return c.conn.Close() + } + + return nil +} + +// SendMessage sends a message through the WebSocket connection +func (c *Client) SendMessage(messageType string, data interface{}) error { + if c.conn == nil { + return fmt.Errorf("not connected") + } + + msg := WSMessage{ + Type: messageType, + Data: data, + } + + logger.Debug("Sending message: %s, data: %+v", messageType, data) + + c.writeMux.Lock() + defer c.writeMux.Unlock() + return c.conn.WriteJSON(msg) +} + +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { + stopChan := make(chan struct{}) + go func() { + count := 0 + maxAttempts := 10 + + err := c.SendMessage(messageType, data) // Send immediately + if err != nil { + logger.Error("Failed to send initial message: %v", err) + } + count++ + + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if count >= maxAttempts { + logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) + return + } + err = c.SendMessage(messageType, data) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + count++ + case <-stopChan: + return + } + } + }() + return func() { + close(stopChan) + } +} + +// RegisterHandler registers a handler for a specific message type +func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { + c.handlersMux.Lock() + defer c.handlersMux.Unlock() + c.handlers[messageType] = handler +} + +func (c *Client) getToken() (string, error) { + // Parse the base URL to ensure we have the correct hostname + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return "", fmt.Errorf("failed to parse base URL: %w", err) + } + + // Ensure we have the base URL without trailing slashes + baseEndpoint := strings.TrimRight(baseURL.String(), "/") + + var tlsConfig *tls.Config = nil + + // Use new TLS configuration method + if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { + tlsConfig, err = c.setupTLS() + if err != nil { + return "", fmt.Errorf("failed to setup TLS configuration: %w", err) + } + } + + // Check for environment variable to skip TLS verification + if os.Getenv("SKIP_TLS_VERIFY") == "true" { + if tlsConfig == nil { + tlsConfig = &tls.Config{} + } + tlsConfig.InsecureSkipVerify = true + logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + } + + var tokenData map[string]interface{} + + // Get a new token + if c.clientType == "newt" { + tokenData = map[string]interface{}{ + "newtId": c.config.ID, + "secret": c.config.Secret, + } + } else if c.clientType == "olm" { + tokenData = map[string]interface{}{ + "olmId": c.config.ID, + "secret": c.config.Secret, + } + } + jsonData, err := json.Marshal(tokenData) + + if err != nil { + return "", fmt.Errorf("failed to marshal token request data: %w", err) + } + + // Create a new request + req, err := http.NewRequest( + "POST", + baseEndpoint+"/api/v1/auth/"+c.clientType+"/get-token", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-CSRF-Token", "x-csrf-protection") + + // Make the request + client := &http.Client{} + if tlsConfig != nil { + client.Transport = &http.Transport{ + TLSClientConfig: tlsConfig, + } + } + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to request new token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + logger.Error("Failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + } + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + logger.Error("Failed to decode token response.") + return "", fmt.Errorf("failed to decode token response: %w", err) + } + + if !tokenResp.Success { + return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) + } + + if tokenResp.Data.Token == "" { + return "", fmt.Errorf("received empty token from server") + } + + logger.Debug("Received token: %s", tokenResp.Data.Token) + + return tokenResp.Data.Token, nil +} + +func (c *Client) connectWithRetry() { + for { + select { + case <-c.done: + return + default: + err := c.establishConnection() + if err != nil { + logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + time.Sleep(c.reconnectInterval) + continue + } + return + } + } +} + +func (c *Client) establishConnection() error { + // Get token for authentication + token, err := c.getToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + if c.onTokenUpdate != nil { + c.onTokenUpdate(token) + } + + // Parse the base URL to determine protocol and hostname + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return fmt.Errorf("failed to parse base URL: %w", err) + } + + // Determine WebSocket protocol based on HTTP protocol + wsProtocol := "wss" + if baseURL.Scheme == "http" { + wsProtocol = "ws" + } + + // Create WebSocket URL + wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host) + u, err := url.Parse(wsURL) + if err != nil { + return fmt.Errorf("failed to parse WebSocket URL: %w", err) + } + + // Add token to query parameters + q := u.Query() + q.Set("token", token) + q.Set("clientType", c.clientType) + u.RawQuery = q.Encode() + + // Connect to WebSocket + dialer := websocket.DefaultDialer + + // Use new TLS configuration method + if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { + logger.Info("Setting up TLS configuration for WebSocket connection") + tlsConfig, err := c.setupTLS() + if err != nil { + return fmt.Errorf("failed to setup TLS configuration: %w", err) + } + dialer.TLSClientConfig = tlsConfig + } + + // Check for environment variable to skip TLS verification for WebSocket connection + if os.Getenv("SKIP_TLS_VERIFY") == "true" { + if dialer.TLSClientConfig == nil { + dialer.TLSClientConfig = &tls.Config{} + } + dialer.TLSClientConfig.InsecureSkipVerify = true + logger.Debug("WebSocket TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") + } + + conn, _, err := dialer.Dial(u.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect to WebSocket: %w", err) + } + + c.conn = conn + c.setConnected(true) + + // Start the ping monitor + go c.pingMonitor() + // Start the read pump with disconnect detection + go c.readPumpWithDisconnectDetection() + + if c.onConnect != nil { + if err := c.onConnect(); err != nil { + logger.Error("OnConnect callback failed: %v", err) + } + } + + return nil +} + +// setupTLS configures TLS based on the TLS configuration +func (c *Client) setupTLS() (*tls.Config, error) { + tlsConfig := &tls.Config{} + + // Handle new separate certificate configuration + if c.tlsConfig.ClientCertFile != "" && c.tlsConfig.ClientKeyFile != "" { + logger.Info("Loading separate certificate files for mTLS") + logger.Debug("Client cert: %s", c.tlsConfig.ClientCertFile) + logger.Debug("Client key: %s", c.tlsConfig.ClientKeyFile) + + // Load client certificate and key + cert, err := tls.LoadX509KeyPair(c.tlsConfig.ClientCertFile, c.tlsConfig.ClientKeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client certificate pair: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + + // Load CA certificates for remote validation if specified + if len(c.tlsConfig.CAFiles) > 0 { + logger.Debug("Loading CA certificates: %v", c.tlsConfig.CAFiles) + caCertPool := x509.NewCertPool() + for _, caFile := range c.tlsConfig.CAFiles { + caCert, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("failed to read CA file %s: %w", caFile, err) + } + + // Try to parse as PEM first, then DER + if !caCertPool.AppendCertsFromPEM(caCert) { + // If PEM parsing failed, try DER + cert, err := x509.ParseCertificate(caCert) + if err != nil { + return nil, fmt.Errorf("failed to parse CA certificate from %s: %w", caFile, err) + } + caCertPool.AddCert(cert) + } + } + tlsConfig.RootCAs = caCertPool + } + + return tlsConfig, nil + } + + // Fallback to existing PKCS12 implementation for backward compatibility + if c.tlsConfig.PKCS12File != "" { + logger.Info("Loading PKCS12 certificate for mTLS (deprecated)") + return c.setupPKCS12TLS() + } + + // Legacy fallback using config.TlsClientCert + if c.config.TlsClientCert != "" { + logger.Info("Loading legacy PKCS12 certificate for mTLS (deprecated)") + return loadClientCertificate(c.config.TlsClientCert) + } + + return nil, nil +} + +// setupPKCS12TLS loads TLS configuration from PKCS12 file +func (c *Client) setupPKCS12TLS() (*tls.Config, error) { + return loadClientCertificate(c.tlsConfig.PKCS12File) +} + +// pingMonitor sends pings at a short interval and triggers reconnect on failure +func (c *Client) pingMonitor() { + ticker := time.NewTicker(c.pingInterval) + defer ticker.Stop() + + for { + select { + case <-c.done: + return + case <-ticker.C: + if c.conn == nil { + return + } + c.writeMux.Lock() + err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(c.pingTimeout)) + c.writeMux.Unlock() + if err != nil { + // Check if we're shutting down before logging error and reconnecting + select { + case <-c.done: + // Expected during shutdown + return + default: + logger.Error("Ping failed: %v", err) + c.reconnect() + return + } + } + } + } +} + +// readPumpWithDisconnectDetection reads messages and triggers reconnect on error +func (c *Client) readPumpWithDisconnectDetection() { + defer func() { + if c.conn != nil { + c.conn.Close() + } + // Only attempt reconnect if we're not shutting down + select { + case <-c.done: + // Shutting down, don't reconnect + return + default: + c.reconnect() + } + }() + + for { + select { + case <-c.done: + return + default: + var msg WSMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + // Check if we're shutting down before logging error + select { + case <-c.done: + // Expected during shutdown, don't log as error + logger.Debug("WebSocket connection closed during shutdown") + return + default: + // Unexpected error during normal operation + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure, websocket.CloseNormalClosure) { + logger.Error("WebSocket read error: %v", err) + } else { + logger.Debug("WebSocket connection closed: %v", err) + } + return // triggers reconnect via defer + } + } + + c.handlersMux.RLock() + if handler, ok := c.handlers[msg.Type]; ok { + handler(msg) + } + c.handlersMux.RUnlock() + } + } +} + +func (c *Client) reconnect() { + c.setConnected(false) + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + + // Only reconnect if we're not shutting down + select { + case <-c.done: + return + default: + go c.connectWithRetry() + } +} + +func (c *Client) setConnected(status bool) { + c.reconnectMux.Lock() + defer c.reconnectMux.Unlock() + c.isConnected = status +} + +// LoadClientCertificate Helper method to load client certificates (PKCS12 format) +func loadClientCertificate(p12Path string) (*tls.Config, error) { + logger.Info("Loading tls-client-cert %s", p12Path) + // Read the PKCS12 file + p12Data, err := os.ReadFile(p12Path) + if err != nil { + return nil, fmt.Errorf("failed to read PKCS12 file: %w", err) + } + + // Parse PKCS12 with empty password for non-encrypted files + privateKey, certificate, caCerts, err := pkcs12.DecodeChain(p12Data, "") + if err != nil { + return nil, fmt.Errorf("failed to decode PKCS12: %w", err) + } + + // Create certificate + cert := tls.Certificate{ + Certificate: [][]byte{certificate.Raw}, + PrivateKey: privateKey, + } + + // Optional: Add CA certificates if present + rootCAs, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("failed to load system cert pool: %w", err) + } + if len(caCerts) > 0 { + for _, caCert := range caCerts { + rootCAs.AddCert(caCert) + } + } + + // Create TLS configuration + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + RootCAs: rootCAs, + }, nil +}