diff --git a/main.go b/main.go index e464f31..dbfe419 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,7 @@ import ( "net/netip" "os" "os/signal" + "path/filepath" "strconv" "strings" "syscall" @@ -48,6 +49,10 @@ type ExitNodeData struct { ExitNodes []ExitNode `json:"exitNodes"` } +type SSHPublicKeyData struct { + PublicKey string `json:"publicKey"` +} + // ExitNode represents an exit node with an ID, endpoint, and weight. type ExitNode struct { ID int `json:"exitNodeId"` @@ -93,6 +98,7 @@ var ( stopFunc func() healthFile string useNativeInterface bool + authorizedKeysFile string ) func main() { @@ -114,6 +120,8 @@ func main() { pingTimeoutStr := os.Getenv("PING_TIMEOUT") dockerEnforceNetworkValidation = os.Getenv("DOCKER_ENFORCE_NETWORK_VALIDATION") healthFile = os.Getenv("HEALTH_FILE") + useNativeInterface = os.Getenv("USE_NATIVE_INTERFACE") == "true" + authorizedKeysFile = os.Getenv("AUTHORIZED_KEYS_FILE") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -155,7 +163,13 @@ func main() { flag.StringVar(&pingIntervalStr, "ping-interval", "3s", "Interval for pinging the server (default 3s)") } if pingTimeoutStr == "" { - flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") + flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)") + } + if pingTimeoutStr == "" { + flag.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 5s)") + } + if authorizedKeysFile == "" { + flag.StringVar(&authorizedKeysFile, "authorized-keys-file", "~/.ssh/authorized_keys", "Path to authorized keys file (if unset, no keys will be authorized)") } if pingIntervalStr != "" { @@ -787,6 +801,93 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } }) + client.RegisterHandler("newt/send/ssh/publicKey", func(msg websocket.WSMessage) { + logger.Debug("Received SSH public key request") + + var sshPublicKeyData SSHPublicKeyData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + if err := json.Unmarshal(jsonData, &sshPublicKeyData); err != nil { + logger.Info("Error unmarshaling SSH public key data: %v", err) + return + } + + sshPublicKey := sshPublicKeyData.PublicKey + + if authorizedKeysFile == "" { + logger.Debug("No authorized keys file set, skipping public key response") + return + } + + // Expand tilde to home directory if present + expandedPath := authorizedKeysFile + if strings.HasPrefix(authorizedKeysFile, "~/") { + homeDir, err := os.UserHomeDir() + if err != nil { + logger.Error("Failed to get user home directory: %v", err) + return + } + expandedPath = filepath.Join(homeDir, authorizedKeysFile[2:]) + } + + // if it is set but the file does not exist, create it + if _, err := os.Stat(expandedPath); os.IsNotExist(err) { + logger.Debug("Authorized keys file does not exist, creating it: %s", expandedPath) + if err := os.MkdirAll(filepath.Dir(expandedPath), 0755); err != nil { + logger.Error("Failed to create directory for authorized keys file: %v", err) + return + } + if _, err := os.Create(expandedPath); err != nil { + logger.Error("Failed to create authorized keys file: %v", err) + return + } + } + + // Check if the public key already exists in the file + fileContent, err := os.ReadFile(expandedPath) + if err != nil { + logger.Error("Failed to read authorized keys file: %v", err) + return + } + + // Check if the key already exists (trim whitespace for comparison) + existingKeys := strings.Split(string(fileContent), "\n") + keyAlreadyExists := false + trimmedNewKey := strings.TrimSpace(sshPublicKey) + + for _, existingKey := range existingKeys { + if strings.TrimSpace(existingKey) == trimmedNewKey && trimmedNewKey != "" { + keyAlreadyExists = true + break + } + } + + if keyAlreadyExists { + logger.Info("SSH public key already exists in authorized keys file, skipping") + return + } + + // append the public key to the authorized keys file + logger.Debug("Appending public key to authorized keys file: %s", sshPublicKey) + file, err := os.OpenFile(expandedPath, os.O_APPEND|os.O_WRONLY, 0644) + if err != nil { + logger.Error("Failed to open authorized keys file: %v", err) + return + } + defer file.Close() + + if _, err := file.WriteString(sshPublicKey + "\n"); err != nil { + logger.Error("Failed to write public key to authorized keys file: %v", err) + return + } + + logger.Info("SSH public key appended to authorized keys file") + }) + client.OnConnect(func() error { publicKey = privateKey.PublicKey() logger.Debug("Public key: %s", publicKey) diff --git a/websocket/config.go b/websocket/config.go index 4b1e5e0..6803e81 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -9,22 +9,27 @@ import ( ) func getConfigPath(clientType string) string { - var configDir string - switch runtime.GOOS { - case "darwin": - configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", clientType+"-client") - case "windows": - logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm") - configDir = filepath.Join(logDir, clientType+"-client") - default: // linux and others - configDir = filepath.Join(os.Getenv("HOME"), ".config", clientType+"-client") + configFile := os.Getenv("CONFIG_FILE") + if configFile == "" { + var configDir string + switch runtime.GOOS { + case "darwin": + configDir = filepath.Join(os.Getenv("HOME"), "Library", "Application Support", clientType+"-client") + case "windows": + logDir := filepath.Join(os.Getenv("PROGRAMDATA"), "olm") + configDir = filepath.Join(logDir, clientType+"-client") + default: // linux and others + configDir = filepath.Join(os.Getenv("HOME"), ".config", clientType+"-client") + } + + if err := os.MkdirAll(configDir, 0755); err != nil { + log.Printf("Failed to create config directory: %v", err) + } + + return filepath.Join(configDir, "config.json") } - if err := os.MkdirAll(configDir, 0755); err != nil { - log.Printf("Failed to create config directory: %v", err) - } - - return filepath.Join(configDir, "config.json") + return configFile } func (c *Client) loadConfig() error {