Refactor ssh server and client

This commit is contained in:
Viktor Liu
2025-06-18 20:49:06 +02:00
parent 520f2cfdb4
commit 6ed846ae29
19 changed files with 3532 additions and 554 deletions

View File

@@ -3,9 +3,11 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
@@ -17,43 +19,34 @@ import (
)
var (
port int
user = "root"
host string
port int
username string
host string
command string
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
Use: "ssh [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH.
split := strings.Split(args[0], "@")
if len(split) == 2 {
user = split[0]
host = split[1]
} else {
host = args[0]
}
return nil
},
Short: "connect to a remote SSH server",
Examples:
netbird ssh peer-hostname
netbird ssh user@peer-hostname
netbird ssh peer-hostname --login myuser
netbird ssh peer-hostname -p 22022
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
if err := util.InitLog(logLevel, "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
@@ -62,7 +55,7 @@ var sshCmd = &cobra.Command{
ConfigPath: configPath,
})
if err != nil {
return err
return fmt.Errorf("update config: %w", err)
}
sig := make(chan os.Signal, 1)
@@ -70,7 +63,6 @@ var sshCmd = &cobra.Command{
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
@@ -88,31 +80,124 @@ var sshCmd = &cobra.Command{
},
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("host argument required")
}
// Reset globals to defaults
port = nbssh.DefaultSSHPort
username = ""
host = ""
command = ""
// Create a new FlagSet for parsing SSH-specific flags
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil) // Suppress error output
// Define SSH-specific flags
portFlag := fs.Int("p", nbssh.DefaultSSHPort, "SSH port")
fs.Int("port", nbssh.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", "SSH username")
fs.String("user", "", "SSH username")
loginFlag := fs.String("login", "", "SSH username (alias for --user)")
// Parse flags until we hit the hostname (first non-flag argument)
err := fs.Parse(args)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
// If flag parsing fails, treat everything as hostname + command
// This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
return parseHostnameAndCommand(args)
}
// Get the remaining args (hostname and command)
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New("host argument required")
}
// Set parsed values
port = *portFlag
if *userFlag != "" {
username = *userFlag
} else if *loginFlag != "" {
username = *loginFlag
}
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New("host argument required")
}
// Parse hostname (possibly with user@host format)
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
// Only use username from host if not already set by flags
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
// Set default username if none provided
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := nbssh.DialWithKey(ctx, target, username, pemKey)
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n\n")
return fmt.Errorf("dial %s: %w", target, err)
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
}
_ = c.Close()
}()
err = c.OpenTerminal()
if err != nil {
return err
if command != "" {
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
return err
}
} else {
if err := c.OpenTerminal(ctx); err != nil {
return err
}
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", "SSH username")
sshCmd.PersistentFlags().StringVar(&username, "login", "", "SSH username (alias for --user)")
}