Refactor ssh server and client

This commit is contained in:
Viktor Liu
2025-06-18 20:49:06 +02:00
parent 520f2cfdb4
commit 6ed846ae29
19 changed files with 3532 additions and 554 deletions

View File

@@ -3,9 +3,11 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"os"
"os/signal"
"os/user"
"strings"
"syscall"
@@ -17,43 +19,34 @@ import (
)
var (
port int
user = "root"
host string
port int
username string
host string
command string
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
Use: "ssh [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH.
split := strings.Split(args[0], "@")
if len(split) == 2 {
user = split[0]
host = split[1]
} else {
host = args[0]
}
return nil
},
Short: "connect to a remote SSH server",
Examples:
netbird ssh peer-hostname
netbird ssh user@peer-hostname
netbird ssh peer-hostname --login myuser
netbird ssh peer-hostname -p 22022
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
err := util.InitLog(logLevel, "console")
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
if err := util.InitLog(logLevel, "console"); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
@@ -62,7 +55,7 @@ var sshCmd = &cobra.Command{
ConfigPath: configPath,
})
if err != nil {
return err
return fmt.Errorf("update config: %w", err)
}
sig := make(chan os.Signal, 1)
@@ -70,7 +63,6 @@ var sshCmd = &cobra.Command{
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
@@ -88,31 +80,124 @@ var sshCmd = &cobra.Command{
},
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), user, pemKey)
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("host argument required")
}
// Reset globals to defaults
port = nbssh.DefaultSSHPort
username = ""
host = ""
command = ""
// Create a new FlagSet for parsing SSH-specific flags
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil) // Suppress error output
// Define SSH-specific flags
portFlag := fs.Int("p", nbssh.DefaultSSHPort, "SSH port")
fs.Int("port", nbssh.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", "SSH username")
fs.String("user", "", "SSH username")
loginFlag := fs.String("login", "", "SSH username (alias for --user)")
// Parse flags until we hit the hostname (first non-flag argument)
err := fs.Parse(args)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
// If flag parsing fails, treat everything as hostname + command
// This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
return parseHostnameAndCommand(args)
}
// Get the remaining args (hostname and command)
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New("host argument required")
}
// Set parsed values
port = *portFlag
if *userFlag != "" {
username = *userFlag
} else if *loginFlag != "" {
username = *loginFlag
}
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New("host argument required")
}
// Parse hostname (possibly with user@host format)
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
// Only use username from host if not already set by flags
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
// Set default username if none provided
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := nbssh.DialWithKey(ctx, target, username, pemKey)
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n\n")
return fmt.Errorf("dial %s: %w", target, err)
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
}
_ = c.Close()
}()
err = c.OpenTerminal()
if err != nil {
return err
if command != "" {
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
return err
}
} else {
if err := c.OpenTerminal(ctx); err != nil {
return err
}
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", "SSH username")
sshCmd.PersistentFlags().StringVar(&username, "login", "", "SSH username (alias for --user)")
}

342
client/cmd/ssh_test.go Normal file
View File

@@ -0,0 +1,342 @@
package cmd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSHCommand_FlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedUser string
expectedPort int
expectedCmd string
expectError bool
}{
{
name: "basic host",
args: []string{"hostname"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22022,
expectedCmd: "",
},
{
name: "user@host format",
args: []string{"user@hostname"},
expectedHost: "hostname",
expectedUser: "user",
expectedPort: 22022,
expectedCmd: "",
},
{
name: "host with command",
args: []string{"hostname", "echo", "hello"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22022,
expectedCmd: "echo hello",
},
{
name: "command with flags should be preserved",
args: []string{"hostname", "ls", "-la", "/tmp"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22022,
expectedCmd: "ls -la /tmp",
},
{
name: "double dash separator",
args: []string{"hostname", "--", "ls", "-la"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22022,
expectedCmd: "-- ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22022
command = ""
// Mock command for testing
cmd := sshCmd
cmd.SetArgs(tt.args)
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expectedHost, host, "host mismatch")
if tt.expectedUser != "" {
assert.Equal(t, tt.expectedUser, username, "username mismatch")
}
assert.Equal(t, tt.expectedPort, port, "port mismatch")
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
})
}
}
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
// Test that SSH flags don't conflict with command flags
tests := []struct {
name string
args []string
expectedCmd string
description string
}{
{
name: "ls with -la flags",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
description: "ls flags should be passed to remote command",
},
{
name: "grep with -r flag",
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
expectedCmd: "grep -r pattern /path",
description: "grep flags should be passed to remote command",
},
{
name: "ps with aux flags",
args: []string{"hostname", "ps", "aux"},
expectedCmd: "ps aux",
description: "ps flags should be passed to remote command",
},
{
name: "command with double dash",
args: []string{"hostname", "--", "ls", "-la"},
expectedCmd: "-- ls -la",
description: "double dash should be preserved in command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22022
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err)
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
// Test that commands with arguments should execute the command and exit,
// not drop to an interactive shell
tests := []struct {
name string
args []string
expectedCmd string
shouldExit bool
description string
}{
{
name: "ls command should execute and exit",
args: []string{"hostname", "ls"},
expectedCmd: "ls",
shouldExit: true,
description: "ls command should execute and exit, not drop to shell",
},
{
name: "ls with flags should execute and exit",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
shouldExit: true,
description: "ls with flags should execute and exit, not drop to shell",
},
{
name: "pwd command should execute and exit",
args: []string{"hostname", "pwd"},
expectedCmd: "pwd",
shouldExit: true,
description: "pwd command should execute and exit, not drop to shell",
},
{
name: "echo command should execute and exit",
args: []string{"hostname", "echo", "hello"},
expectedCmd: "echo hello",
shouldExit: true,
description: "echo command should execute and exit, not drop to shell",
},
{
name: "no command should open shell",
args: []string{"hostname"},
expectedCmd: "",
shouldExit: false,
description: "no command should open interactive shell",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22022
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err)
assert.Equal(t, tt.expectedCmd, command, tt.description)
// When command is present, it should execute the command and exit
// When command is empty, it should open interactive shell
hasCommand := command != ""
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
})
}
}
func TestSSHCommand_FlagHandling(t *testing.T) {
// Test that flags after hostname are not parsed by netbird but passed to SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "ls with -la flag should not be parsed by netbird",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
},
{
name: "command with netbird-like flags should be passed through",
args: []string{"hostname", "echo", "--help"},
expectedHost: "hostname",
expectedCmd: "echo --help",
expectError: false,
description: "--help should be passed to echo, not parsed by netbird",
},
{
name: "command with -p flag should not conflict with SSH port flag",
args: []string{"hostname", "ps", "-p", "1234"},
expectedHost: "hostname",
expectedCmd: "ps -p 1234",
expectError: false,
description: "ps -p should be passed to ps command, not parsed as port",
},
{
name: "tar with flags should be passed through",
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
expectedHost: "hostname",
expectedCmd: "tar -czf backup.tar.gz /home",
expectError: false,
description: "tar flags should be passed to tar command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22022
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
// should not parse -la as netbird flags but pass them to the SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "original issue: ls -la should be preserved",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "The original failing case should now work",
},
{
name: "ls -l should be preserved",
args: []string{"hostname", "ls", "-l"},
expectedHost: "hostname",
expectedCmd: "ls -l",
expectError: false,
description: "Single letter flags should be preserved",
},
{
name: "SSH port flag should work",
args: []string{"-p", "2222", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedCmd: "ls -la",
expectError: false,
description: "SSH -p flag should be parsed, command flags preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22022
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// Check port for the test case with -p flag
if len(tt.args) > 0 && tt.args[0] == "-p" {
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
}
})
}
}