diff --git a/README.md b/README.md index 848f302..ba7a29a 100644 --- a/README.md +++ b/README.md @@ -1,20 +1,14 @@ # Olm -Olm is a fully user space [WireGuard](https://www.wireguard.com/) tunnel olm and TCP/UDP proxy, designed to securely expose private resources controlled by Pangolin. By using Olm, you don't need to manage complex WireGuard tunnels and NATing. +Olm is a [WireGuard](https://www.wireguard.com/) tunnel manager designed to securely connect to private resources. By using Olm, you don't need to manage complex WireGuard tunnels. ### Installation and Documentation -Olm is used with Pangolin and Gerbil as part of the larger system. See documentation below: +Olm is used with Pangolin and Newt as part of the larger system. See documentation below: - [Installation Instructions](https://docs.fossorial.io) - [Full Documentation](https://docs.fossorial.io) -## Preview - -Preview - -_Sample output of a Olm container connected to Pangolin and hosting various resource target proxies._ - ## Key Functions ### Registers with Pangolin diff --git a/main.go b/main.go index 4076739..684754d 100644 --- a/main.go +++ b/main.go @@ -16,7 +16,7 @@ import ( "time" "github.com/fosrl/newt/logger" - "github.com/fosrl/newt/websocket" + "github.com/fosrl/olm/websocket" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -356,7 +356,7 @@ func main() { } // Create a new olm - olm, err := websocket.NewOlm( + olm, err := websocket.NewClient( id, // CLI arg takes precedence secret, // CLI arg takes precedence endpoint, diff --git a/public/screenshots/preview.png b/public/screenshots/preview.png deleted file mode 100644 index c6a8cd8..0000000 Binary files a/public/screenshots/preview.png and /dev/null differ diff --git a/websocket/client.go b/websocket/client.go new file mode 100644 index 0000000..8339f88 --- /dev/null +++ b/websocket/client.go @@ -0,0 +1,351 @@ +package websocket + +import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/fosrl/newt/logger" + + "github.com/gorilla/websocket" +) + +type Client struct { + conn *websocket.Conn + config *Config + baseURL string + handlers map[string]MessageHandler + done chan struct{} + handlersMux sync.RWMutex + + reconnectInterval time.Duration + isConnected bool + reconnectMux sync.RWMutex + + onConnect func() error +} + +type ClientOption func(*Client) + +type MessageHandler func(message WSMessage) + +// WithBaseURL sets the base URL for the client +func WithBaseURL(url string) ClientOption { + return func(c *Client) { + c.baseURL = url + } +} + +func (c *Client) OnConnect(callback func() error) { + c.onConnect = callback +} + +// NewClient creates a new Olm client +func NewClient(olmID, secret string, endpoint string, opts ...ClientOption) (*Client, error) { + config := &Config{ + OlmID: olmID, + Secret: secret, + Endpoint: endpoint, + } + + client := &Client{ + config: config, + baseURL: endpoint, // default value + handlers: make(map[string]MessageHandler), + done: make(chan struct{}), + reconnectInterval: 10 * time.Second, + isConnected: false, + } + + // Apply options before loading config + for _, opt := range opts { + opt(client) + } + + // Load existing config if available + if err := client.loadConfig(); err != nil { + return nil, fmt.Errorf("failed to load config: %w", err) + } + + return client, nil +} + +// Connect establishes the WebSocket connection +func (c *Client) Connect() error { + go c.connectWithRetry() + return nil +} + +// Close closes the WebSocket connection +func (c *Client) Close() error { + close(c.done) + if c.conn != nil { + return c.conn.Close() + } + + // stop the ping monitor + c.setConnected(false) + + return nil +} + +// SendMessage sends a message through the WebSocket connection +func (c *Client) SendMessage(messageType string, data interface{}) error { + if c.conn == nil { + return fmt.Errorf("not connected") + } + + msg := WSMessage{ + Type: messageType, + Data: data, + } + + return c.conn.WriteJSON(msg) +} + +// RegisterHandler registers a handler for a specific message type +func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { + c.handlersMux.Lock() + defer c.handlersMux.Unlock() + c.handlers[messageType] = handler +} + +// readPump pumps messages from the WebSocket connection +func (c *Client) readPump() { + defer c.conn.Close() + + for { + select { + case <-c.done: + return + default: + var msg WSMessage + err := c.conn.ReadJSON(&msg) + if err != nil { + return + } + + c.handlersMux.RLock() + if handler, ok := c.handlers[msg.Type]; ok { + handler(msg) + } + c.handlersMux.RUnlock() + } + } +} + +func (c *Client) getToken() (string, error) { + // Parse the base URL to ensure we have the correct hostname + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return "", fmt.Errorf("failed to parse base URL: %w", err) + } + + // Ensure we have the base URL without trailing slashes + baseEndpoint := strings.TrimRight(baseURL.String(), "/") + + // If we already have a token, try to use it + if c.config.Token != "" { + tokenCheckData := map[string]interface{}{ + "olmId": c.config.OlmID, + "secret": c.config.Secret, + "token": c.config.Token, + } + jsonData, err := json.Marshal(tokenCheckData) + if err != nil { + return "", fmt.Errorf("failed to marshal token check data: %w", err) + } + + // Create a new request + req, err := http.NewRequest( + "POST", + baseEndpoint+"/api/v1/auth/olm/get-token", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-CSRF-Token", "x-csrf-protection") + + // Make the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to check token validity: %w", err) + } + defer resp.Body.Close() + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token check response: %w", err) + } + + // If token is still valid, return it + if tokenResp.Success && tokenResp.Message == "Token session already valid" { + return c.config.Token, nil + } + } + + // Get a new token + tokenData := map[string]interface{}{ + "olmId": c.config.OlmID, + "secret": c.config.Secret, + } + jsonData, err := json.Marshal(tokenData) + if err != nil { + return "", fmt.Errorf("failed to marshal token request data: %w", err) + } + + // Create a new request + req, err := http.NewRequest( + "POST", + baseEndpoint+"/api/v1/auth/olm/get-token", + bytes.NewBuffer(jsonData), + ) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + // Set headers + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-CSRF-Token", "x-csrf-protection") + + // Make the request + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to request new token: %w", err) + } + defer resp.Body.Close() + + var tokenResp TokenResponse + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token response: %w", err) + } + + if !tokenResp.Success { + return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) + } + + if tokenResp.Data.Token == "" { + return "", fmt.Errorf("received empty token from server") + } + + return tokenResp.Data.Token, nil +} + +func (c *Client) connectWithRetry() { + for { + select { + case <-c.done: + return + default: + err := c.establishConnection() + if err != nil { + logger.Error("Failed to connect: %v. Retrying in %v...", err, c.reconnectInterval) + time.Sleep(c.reconnectInterval) + continue + } + return + } + } +} + +func (c *Client) establishConnection() error { + // Get token for authentication + token, err := c.getToken() + if err != nil { + return fmt.Errorf("failed to get token: %w", err) + } + + // Parse the base URL to determine protocol and hostname + baseURL, err := url.Parse(c.baseURL) + if err != nil { + return fmt.Errorf("failed to parse base URL: %w", err) + } + + // Determine WebSocket protocol based on HTTP protocol + wsProtocol := "wss" + if baseURL.Scheme == "http" { + wsProtocol = "ws" + } + + // Create WebSocket URL + wsURL := fmt.Sprintf("%s://%s/api/v1/ws", wsProtocol, baseURL.Host) + u, err := url.Parse(wsURL) + if err != nil { + return fmt.Errorf("failed to parse WebSocket URL: %w", err) + } + + // Add token to query parameters + q := u.Query() + q.Set("token", token) + u.RawQuery = q.Encode() + + // Connect to WebSocket + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + return fmt.Errorf("failed to connect to WebSocket: %w", err) + } + + c.conn = conn + c.setConnected(true) + + // Start the ping monitor + go c.pingMonitor() + // Start the read pump + go c.readPump() + + if c.onConnect != nil { + err := c.saveConfig() + if err != nil { + logger.Error("Failed to save config: %v", err) + } + if err := c.onConnect(); err != nil { + logger.Error("OnConnect callback failed: %v", err) + } + } + + return nil +} + +func (c *Client) pingMonitor() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.done: + return + case <-ticker.C: + if err := c.conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(10*time.Second)); err != nil { + logger.Error("Ping failed: %v", err) + c.reconnect() + return + } + } + } +} + +func (c *Client) reconnect() { + c.setConnected(false) + if c.conn != nil { + c.conn.Close() + } + + go c.connectWithRetry() +} + +func (c *Client) setConnected(status bool) { + c.reconnectMux.Lock() + defer c.reconnectMux.Unlock() + c.isConnected = status +} diff --git a/websocket/config.go b/websocket/config.go new file mode 100644 index 0000000..6e54042 --- /dev/null +++ b/websocket/config.go @@ -0,0 +1,72 @@ +package websocket + +import ( + "encoding/json" + "log" + "os" + "path/filepath" + "runtime" +) + +func getConfigPath() string { + var configDir string + switch runtime.GOOS { + case "darwin": + configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", "olm-client") + case "windows": + configDir = filepath.Join(os.Getenv("APPDATA"), "olm-client") + default: // linux and others + configDir = filepath.Join(os.Getenv("HOME"), ".config", "olm-client") + } + + if err := os.MkdirAll(configDir, 0755); err != nil { + log.Printf("Failed to create config directory: %v", err) + } + + return filepath.Join(configDir, "config.json") +} + +func (c *Client) loadConfig() error { + if c.config.OlmID != "" && c.config.Secret != "" && c.config.Endpoint != "" { + return nil + } + + configPath := getConfigPath() + data, err := os.ReadFile(configPath) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return err + } + + var config Config + if err := json.Unmarshal(data, &config); err != nil { + return err + } + + if c.config.OlmID == "" { + c.config.OlmID = config.OlmID + } + if c.config.Token == "" { + c.config.Token = config.Token + } + if c.config.Secret == "" { + c.config.Secret = config.Secret + } + if c.config.Endpoint == "" { + c.config.Endpoint = config.Endpoint + c.baseURL = config.Endpoint + } + + return nil +} + +func (c *Client) saveConfig() error { + configPath := getConfigPath() + data, err := json.MarshalIndent(c.config, "", " ") + if err != nil { + return err + } + return os.WriteFile(configPath, data, 0644) +} diff --git a/websocket/types.go b/websocket/types.go new file mode 100644 index 0000000..7786745 --- /dev/null +++ b/websocket/types.go @@ -0,0 +1,21 @@ +package websocket + +type Config struct { + OlmID string `json:"olmId"` + Secret string `json:"secret"` + Token string `json:"token"` + Endpoint string `json:"endpoint"` +} + +type TokenResponse struct { + Data struct { + Token string `json:"token"` + } `json:"data"` + Success bool `json:"success"` + Message string `json:"message"` +} + +type WSMessage struct { + Type string `json:"type"` + Data interface{} `json:"data"` +}