mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
Refactor ssh server and client
This commit is contained in:
@@ -3,9 +3,11 @@ package cmd
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
@@ -17,43 +19,34 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
user = "root"
|
||||
host string
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
)
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host",
|
||||
Args: func(cmd *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("requires a host argument")
|
||||
}
|
||||
Use: "ssh [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH.
|
||||
|
||||
split := strings.Split(args[0], "@")
|
||||
if len(split) == 2 {
|
||||
user = split[0]
|
||||
host = split[1]
|
||||
} else {
|
||||
host = args[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
Short: "connect to a remote SSH server",
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh user@peer-hostname
|
||||
netbird ssh peer-hostname --login myuser
|
||||
netbird ssh peer-hostname -p 22022
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
err := util.InitLog(logLevel, "console")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed initializing log %v", err)
|
||||
}
|
||||
|
||||
if !util.IsAdmin() {
|
||||
cmd.Printf("error: you must have Administrator privileges to run this command\n")
|
||||
return nil
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
@@ -62,7 +55,7 @@ var sshCmd = &cobra.Command{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("update config: %w", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
@@ -70,7 +63,6 @@ var sshCmd = &cobra.Command{
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
// blocking
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
@@ -88,31 +80,124 @@ var sshCmd = &cobra.Command{
|
||||
},
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("host argument required")
|
||||
}
|
||||
|
||||
// Reset globals to defaults
|
||||
port = nbssh.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
|
||||
// Create a new FlagSet for parsing SSH-specific flags
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil) // Suppress error output
|
||||
|
||||
// Define SSH-specific flags
|
||||
portFlag := fs.Int("p", nbssh.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", nbssh.DefaultSSHPort, "SSH port")
|
||||
userFlag := fs.String("u", "", "SSH username")
|
||||
fs.String("user", "", "SSH username")
|
||||
loginFlag := fs.String("login", "", "SSH username (alias for --user)")
|
||||
|
||||
// Parse flags until we hit the hostname (first non-flag argument)
|
||||
err := fs.Parse(args)
|
||||
if err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
|
||||
"\nYou can verify the connection by running:\n\n" +
|
||||
" netbird status\n\n")
|
||||
return err
|
||||
// If flag parsing fails, treat everything as hostname + command
|
||||
// This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
|
||||
return parseHostnameAndCommand(args)
|
||||
}
|
||||
|
||||
// Get the remaining args (hostname and command)
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New("host argument required")
|
||||
}
|
||||
|
||||
// Set parsed values
|
||||
port = *portFlag
|
||||
if *userFlag != "" {
|
||||
username = *userFlag
|
||||
} else if *loginFlag != "" {
|
||||
username = *loginFlag
|
||||
}
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("host argument required")
|
||||
}
|
||||
|
||||
// Parse hostname (possibly with user@host format)
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
// Only use username from host if not already set by flags
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
host = parts[1]
|
||||
} else {
|
||||
host = arg
|
||||
}
|
||||
|
||||
// Set default username if none provided
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
} else if currentUser, err := user.Current(); err == nil {
|
||||
username = currentUser.Username
|
||||
} else {
|
||||
username = "root"
|
||||
}
|
||||
}
|
||||
|
||||
// Everything after hostname becomes the command
|
||||
if len(args) > 1 {
|
||||
command = strings.Join(args[1:], " ")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := nbssh.DialWithKey(ctx, target, username, pemKey)
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
err = c.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = c.Close()
|
||||
}()
|
||||
|
||||
err = c.OpenTerminal()
|
||||
if err != nil {
|
||||
return err
|
||||
if command != "" {
|
||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", "SSH username")
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", "SSH username (alias for --user)")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user