mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Complete overhaul
This commit is contained in:
@@ -35,7 +35,6 @@ const (
|
||||
wireguardPortFlag = "wireguard-port"
|
||||
networkMonitorFlag = "network-monitor"
|
||||
disableAutoConnectFlag = "disable-auto-connect"
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||
dnsRouteIntervalFlag = "dns-router-interval"
|
||||
systemInfoFlag = "system-info"
|
||||
@@ -67,7 +66,6 @@ var (
|
||||
customDNSAddress string
|
||||
rosenpassEnabled bool
|
||||
rosenpassPermissive bool
|
||||
serverSSHAllowed bool
|
||||
interfaceName string
|
||||
wireguardPort uint16
|
||||
networkMonitor bool
|
||||
@@ -182,7 +180,6 @@ func init() {
|
||||
)
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||
|
||||
|
||||
@@ -14,113 +14,375 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
const (
|
||||
sshUsernameDesc = "SSH username"
|
||||
hostArgumentRequired = "host argument required"
|
||||
|
||||
serverSSHAllowedFlag = "allow-server-ssh"
|
||||
enableSSHRootFlag = "enable-ssh-root"
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
)
|
||||
|
||||
var (
|
||||
port int
|
||||
username string
|
||||
host string
|
||||
command string
|
||||
localForwards []string
|
||||
remoteForwards []string
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
)
|
||||
|
||||
var (
|
||||
serverSSHAllowed bool
|
||||
enableSSHRoot bool
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
}
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
Use: "ssh [user@]host [command]",
|
||||
Use: "ssh [flags] [user@]host [command]",
|
||||
Short: "Connect to a NetBird peer via SSH",
|
||||
Long: `Connect to a NetBird peer using SSH.
|
||||
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
|
||||
|
||||
Port Forwarding:
|
||||
-L [bind_address:]port:host:hostport Local port forwarding
|
||||
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
|
||||
-R [bind_address:]port:host:hostport Remote port forwarding
|
||||
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
|
||||
|
||||
SSH Options:
|
||||
-p, --port int Remote SSH port (default 22)
|
||||
-u, --user string SSH username
|
||||
--login string SSH username (alias for --user)
|
||||
--strict-host-key-checking Enable strict host key checking (default: true)
|
||||
-o, --known-hosts string Path to known_hosts file
|
||||
-i, --identity string Path to SSH private key file
|
||||
|
||||
Examples:
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh root@peer-hostname
|
||||
netbird ssh --login root peer-hostname
|
||||
netbird ssh peer-hostname
|
||||
netbird ssh peer-hostname ls -la
|
||||
netbird ssh peer-hostname whoami`,
|
||||
netbird ssh peer-hostname whoami
|
||||
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
|
||||
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
|
||||
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
|
||||
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
|
||||
DisableFlagParsing: true,
|
||||
Args: validateSSHArgsWithoutFlagParsing,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check if help was requested
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
RunE: sshFn,
|
||||
Aliases: []string{"ssh"},
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
// Check if help was requested
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
// Global flags were already parsed by validateSSHArgsWithoutFlagParsing
|
||||
// No additional parsing needed here
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if logFile != "" && logFile != "/var/log/netbird/client.log" {
|
||||
logOutput = logFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
|
||||
func getEnvOrDefault(flagName, defaultValue string) string {
|
||||
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
|
||||
return envValue
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// resetSSHGlobals sets SSH globals to their default values
|
||||
func resetSSHGlobals() {
|
||||
port = sshserver.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
strictHostKeyChecking = true
|
||||
knownHostsFile = ""
|
||||
identityFile = ""
|
||||
}
|
||||
|
||||
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
|
||||
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
|
||||
var localForwardFlags []string
|
||||
var remoteForwardFlags []string
|
||||
var filteredArgs []string
|
||||
|
||||
for i := 0; i < len(args); i++ {
|
||||
arg := args[i]
|
||||
if strings.HasPrefix(arg, "-L") {
|
||||
if arg == "-L" && i+1 < len(args) {
|
||||
localForwardFlags = append(localForwardFlags, args[i+1])
|
||||
i++
|
||||
} else if len(arg) > 2 {
|
||||
localForwardFlags = append(localForwardFlags, arg[2:])
|
||||
}
|
||||
} else if strings.HasPrefix(arg, "-R") {
|
||||
if arg == "-R" && i+1 < len(args) {
|
||||
remoteForwardFlags = append(remoteForwardFlags, args[i+1])
|
||||
i++
|
||||
} else if len(arg) > 2 {
|
||||
remoteForwardFlags = append(remoteForwardFlags, arg[2:])
|
||||
}
|
||||
} else {
|
||||
filteredArgs = append(filteredArgs, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return filteredArgs, localForwardFlags, remoteForwardFlags
|
||||
}
|
||||
|
||||
// extractGlobalFlags parses global flags that were passed before 'ssh' command
|
||||
func extractGlobalFlags(args []string) {
|
||||
sshPos := findSSHCommandPosition(args)
|
||||
if sshPos == -1 {
|
||||
return
|
||||
}
|
||||
|
||||
globalArgs := args[:sshPos]
|
||||
parseGlobalArgs(globalArgs)
|
||||
}
|
||||
|
||||
// findSSHCommandPosition locates the 'ssh' command in the argument list
|
||||
func findSSHCommandPosition(args []string) int {
|
||||
for i, arg := range args {
|
||||
if arg == "ssh" {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) { logFile = value },
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
"c": configFlag,
|
||||
"l": logLevelFlag,
|
||||
}
|
||||
|
||||
for i := 0; i < len(globalArgs); i++ {
|
||||
arg := globalArgs[i]
|
||||
|
||||
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
|
||||
i = nextIndex
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// parseFlag handles generic flag parsing for both long and short forms
|
||||
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
|
||||
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex
|
||||
}
|
||||
|
||||
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
|
||||
flagHandlers[parsedValue.flagName](parsedValue.value)
|
||||
return true, currentIndex + 1
|
||||
}
|
||||
|
||||
return false, currentIndex
|
||||
}
|
||||
|
||||
type parsedFlag struct {
|
||||
flagName string
|
||||
value string
|
||||
}
|
||||
|
||||
// parseEqualsFormat handles --flag=value and -f=value formats
|
||||
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if !strings.Contains(arg, "=") {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(arg, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "--") {
|
||||
flagName := strings.TrimPrefix(parts[0], "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
|
||||
shortFlag := strings.TrimPrefix(parts[0], "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: parts[1]}, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
// parseSpacedFormat handles --flag value and -f value formats
|
||||
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
|
||||
if currentIndex+1 >= len(args) {
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
if strings.HasPrefix(arg, "--") {
|
||||
flagName := strings.TrimPrefix(arg, "--")
|
||||
if _, exists := flagHandlers[flagName]; exists {
|
||||
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
|
||||
}
|
||||
}
|
||||
|
||||
ctx := internal.CtxInitState(cmd.Context())
|
||||
|
||||
config, err := internal.UpdateConfig(internal.ConfigInput{
|
||||
ConfigPath: configPath,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("update config: %w", err)
|
||||
}
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
|
||||
sshctx, cancel := context.WithCancel(ctx)
|
||||
|
||||
go func() {
|
||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||
cmd.Printf("Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
|
||||
shortFlag := strings.TrimPrefix(arg, "-")
|
||||
if longFlag, exists := shortFlags[shortFlag]; exists {
|
||||
if _, exists := flagHandlers[longFlag]; exists {
|
||||
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
|
||||
}
|
||||
cancel()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-sig:
|
||||
cancel()
|
||||
case <-sshctx.Done():
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
return parsedFlag{}, false
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string, *string) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
defaultLogFile := getEnvOrDefault("LOG_FILE", logFile)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||
userFlag := fs.String("u", "", sshUsernameDesc)
|
||||
fs.String("user", "", sshUsernameDesc)
|
||||
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
|
||||
|
||||
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
|
||||
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
|
||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||
identityFlag := fs.String("i", "", "Path to SSH private key file")
|
||||
fs.String("identity", "", "Path to SSH private key file")
|
||||
|
||||
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
|
||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||
logFileFlag := fs.String("log-file", defaultLogFile, "sets Netbird log path")
|
||||
|
||||
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag, logFileFlag
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("host argument required")
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
// Reset globals to defaults
|
||||
port = nbssh.DefaultSSHPort
|
||||
username = ""
|
||||
host = ""
|
||||
command = ""
|
||||
resetSSHGlobals()
|
||||
|
||||
// Create a new FlagSet for parsing SSH-specific flags
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil) // Suppress error output
|
||||
|
||||
// Define SSH-specific flags
|
||||
portFlag := fs.Int("p", nbssh.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", nbssh.DefaultSSHPort, "SSH port")
|
||||
userFlag := fs.String("u", "", "SSH username")
|
||||
fs.String("user", "", "SSH username")
|
||||
loginFlag := fs.String("login", "", "SSH username (alias for --user)")
|
||||
|
||||
// Parse flags until we hit the hostname (first non-flag argument)
|
||||
err := fs.Parse(args)
|
||||
if err != nil {
|
||||
// If flag parsing fails, treat everything as hostname + command
|
||||
// This handles cases like `ssh hostname ls -la` where `-la` should be part of the command
|
||||
return parseHostnameAndCommand(args)
|
||||
// Extract global flags that were passed before 'ssh' by checking original command line
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag, logFileFlag := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
return parseHostnameAndCommand(filteredArgs)
|
||||
}
|
||||
|
||||
// Get the remaining args (hostname and command)
|
||||
remaining := fs.Args()
|
||||
if len(remaining) < 1 {
|
||||
return errors.New("host argument required")
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
// Set parsed values
|
||||
@@ -131,12 +393,31 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
username = *loginFlag
|
||||
}
|
||||
|
||||
strictHostKeyChecking = *strictHostKeyCheckingFlag
|
||||
knownHostsFile = *knownHostsFlag
|
||||
identityFile = *identityFlag
|
||||
|
||||
// Global flags were already extracted in extractGlobalFlags()
|
||||
// Only override with SSH-specific flags if they were explicitly provided
|
||||
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = *configFlag
|
||||
}
|
||||
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = *logLevelFlag
|
||||
}
|
||||
if *logFileFlag != getEnvOrDefault("LOG_FILE", logFile) {
|
||||
logFile = *logFileFlag
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
|
||||
return parseHostnameAndCommand(remaining)
|
||||
}
|
||||
|
||||
func parseHostnameAndCommand(args []string) error {
|
||||
if len(args) < 1 {
|
||||
return errors.New("host argument required")
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
// Parse hostname (possibly with user@host format)
|
||||
@@ -174,43 +455,221 @@ func parseHostnameAndCommand(args []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
c, err := nbssh.DialWithKey(ctx, target, username, pemKey)
|
||||
|
||||
var c *sshclient.Client
|
||||
var err error
|
||||
|
||||
if strictHostKeyChecking {
|
||||
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
})
|
||||
} else {
|
||||
c, err = sshclient.DialInsecure(ctx, target, username)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
if strictHostKeyChecking {
|
||||
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
|
||||
}
|
||||
cmd.Printf("\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
sshCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
_ = c.Close()
|
||||
<-sshCtx.Done()
|
||||
if err := c.Close(); err != nil {
|
||||
cmd.Printf("Error closing SSH connection: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
|
||||
return fmt.Errorf("start port forwarding: %w", err)
|
||||
}
|
||||
|
||||
if command != "" {
|
||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
return executeSSHCommand(sshCtx, c, command)
|
||||
}
|
||||
return openSSHTerminal(sshCtx, c)
|
||||
}
|
||||
|
||||
// executeSSHCommand executes a command over SSH.
|
||||
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
|
||||
if err := c.ExecuteCommandWithIO(ctx, command); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
} else {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// openSSHTerminal opens an interactive SSH terminal.
|
||||
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
|
||||
if err := c.OpenTerminal(ctx); err != nil {
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("open terminal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// startPortForwarding starts local and remote port forwarding based on command line flags
|
||||
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
|
||||
for _, forward := range localForwards {
|
||||
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("local port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, forward := range remoteForwards {
|
||||
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
|
||||
return fmt.Errorf("remote port forward %s: %w", forward, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", "SSH username")
|
||||
sshCmd.PersistentFlags().StringVar(&username, "login", "", "SSH username (alias for --user)")
|
||||
// parseAndStartLocalForward parses and starts a local port forward (-L)
|
||||
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Local port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
|
||||
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
|
||||
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
|
||||
|
||||
go func() {
|
||||
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
|
||||
cmd.Printf("Remote port forward error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
|
||||
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
|
||||
func parsePortForwardSpec(spec string) (string, string, error) {
|
||||
// Support formats:
|
||||
// port:host:hostport -> localhost:port -> host:hostport
|
||||
// host:port:host:hostport -> host:port -> host:hostport
|
||||
// [host]:port:host:hostport -> [host]:port -> host:hostport
|
||||
// port:unix_socket_path -> localhost:port -> unix_socket_path
|
||||
// host:port:unix_socket_path -> host:port -> unix_socket_path
|
||||
|
||||
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
|
||||
return parseIPv6ForwardSpec(spec)
|
||||
}
|
||||
|
||||
parts := strings.Split(spec, ":")
|
||||
if len(parts) < 2 {
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
|
||||
}
|
||||
|
||||
switch len(parts) {
|
||||
case 2:
|
||||
return parseTwoPartForwardSpec(parts, spec)
|
||||
case 3:
|
||||
return parseThreePartForwardSpec(parts)
|
||||
case 4:
|
||||
return parseFourPartForwardSpec(parts)
|
||||
default:
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
|
||||
}
|
||||
}
|
||||
|
||||
// parseTwoPartForwardSpec handles "port:unix_socket" format.
|
||||
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
|
||||
if isUnixSocket(parts[1]) {
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
|
||||
}
|
||||
|
||||
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
|
||||
func parseThreePartForwardSpec(parts []string) (string, string, error) {
|
||||
if isUnixSocket(parts[2]) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
localAddr := "localhost:" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
|
||||
func parseFourPartForwardSpec(parts []string) (string, string, error) {
|
||||
localHost := normalizeLocalHost(parts[0])
|
||||
localAddr := localHost + ":" + parts[1]
|
||||
remoteAddr := parts[2] + ":" + parts[3]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
|
||||
func parseIPv6ForwardSpec(spec string) (string, string, error) {
|
||||
idx := strings.Index(spec, "]:")
|
||||
if idx == -1 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
|
||||
}
|
||||
|
||||
ipv6Host := spec[:idx+1]
|
||||
remaining := spec[idx+2:]
|
||||
|
||||
parts := strings.Split(remaining, ":")
|
||||
if len(parts) != 3 {
|
||||
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
|
||||
}
|
||||
|
||||
localAddr := ipv6Host + ":" + parts[0]
|
||||
remoteAddr := parts[1] + ":" + parts[2]
|
||||
return localAddr, remoteAddr, nil
|
||||
}
|
||||
|
||||
// isUnixSocket checks if a path is a Unix socket path.
|
||||
func isUnixSocket(path string) bool {
|
||||
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
|
||||
}
|
||||
|
||||
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
|
||||
func normalizeLocalHost(host string) string {
|
||||
if host == "*" {
|
||||
return "0.0.0.0"
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
74
client/cmd/ssh_exec_unix.go
Normal file
74
client/cmd/ssh_exec_unix.go
Normal file
@@ -0,0 +1,74 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sshExecUID uint32
|
||||
sshExecGID uint32
|
||||
sshExecGroups []uint
|
||||
sshExecWorkingDir string
|
||||
sshExecShell string
|
||||
sshExecCommand string
|
||||
sshExecPTY bool
|
||||
)
|
||||
|
||||
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
|
||||
var sshExecCmd = &cobra.Command{
|
||||
Use: "exec",
|
||||
Short: "Internal SSH execution with privilege dropping (hidden)",
|
||||
Hidden: true,
|
||||
RunE: runSSHExec,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
|
||||
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
|
||||
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
|
||||
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
|
||||
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
|
||||
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
|
||||
|
||||
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
sshCmd.AddCommand(sshExecCmd)
|
||||
}
|
||||
|
||||
// runSSHExec handles the SSH exec subcommand execution.
|
||||
func runSSHExec(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sshExecGroups {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sshExecUID,
|
||||
GID: sshExecGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sshExecWorkingDir,
|
||||
Shell: sshExecShell,
|
||||
Command: sshExecCommand,
|
||||
PTY: sshExecPTY,
|
||||
}
|
||||
|
||||
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_unix.go
Normal file
94
client/cmd/ssh_sftp_unix.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build unix
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpUID uint32
|
||||
sftpGID uint32
|
||||
sftpGroupsInt []uint
|
||||
sftpWorkingDir string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with privilege dropping (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
|
||||
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
|
||||
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
privilegeDropper := sshserver.NewPrivilegeDropper()
|
||||
|
||||
var groups []uint32
|
||||
for _, groupInt := range sftpGroupsInt {
|
||||
groups = append(groups, uint32(groupInt))
|
||||
}
|
||||
|
||||
config := sshserver.ExecutorConfig{
|
||||
UID: sftpUID,
|
||||
GID: sftpGID,
|
||||
Groups: groups,
|
||||
WorkingDir: sftpWorkingDir,
|
||||
Shell: "",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
cmd.PrintErrf("privilege drop failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
cmd.PrintErrf("SFTP server close error: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Tracef("starting SFTP server with dropped privileges")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
94
client/cmd/ssh_sftp_windows.go
Normal file
94
client/cmd/ssh_sftp_windows.go
Normal file
@@ -0,0 +1,94 @@
|
||||
//go:build windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/user"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
)
|
||||
|
||||
var (
|
||||
sftpWorkingDir string
|
||||
windowsUsername string
|
||||
windowsDomain string
|
||||
)
|
||||
|
||||
var sshSftpCmd = &cobra.Command{
|
||||
Use: "sftp",
|
||||
Short: "SFTP server with user switching for Windows (internal use)",
|
||||
Hidden: true,
|
||||
RunE: sftpMain,
|
||||
}
|
||||
|
||||
func init() {
|
||||
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
|
||||
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
|
||||
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
|
||||
}
|
||||
|
||||
func sftpMain(cmd *cobra.Command, _ []string) error {
|
||||
return sftpMainDirect(cmd)
|
||||
}
|
||||
|
||||
func sftpMainDirect(cmd *cobra.Command) error {
|
||||
currentUser, err := user.Current()
|
||||
if err != nil {
|
||||
cmd.PrintErrf("failed to get current user: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
|
||||
if windowsUsername != "" {
|
||||
expectedUsername := windowsUsername
|
||||
if windowsDomain != "" {
|
||||
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
|
||||
}
|
||||
if currentUser.Username != expectedUsername && currentUser.Username != windowsUsername {
|
||||
cmd.PrintErrf("user switching failed\n")
|
||||
os.Exit(sshserver.ExitCodeValidationFail)
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
|
||||
|
||||
if sftpWorkingDir != "" {
|
||||
if err := os.Chdir(sftpWorkingDir); err != nil {
|
||||
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
sftpServer, err := sftp.NewServer(struct {
|
||||
io.Reader
|
||||
io.WriteCloser
|
||||
}{
|
||||
Reader: os.Stdin,
|
||||
WriteCloser: os.Stdout,
|
||||
})
|
||||
if err != nil {
|
||||
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("SFTP server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
log.Debugf("starting SFTP server")
|
||||
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
|
||||
cmd.PrintErrf("SFTP server error: %v\n", err)
|
||||
os.Exit(sshserver.ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
os.Exit(sshserver.ExitCodeSuccess)
|
||||
return nil
|
||||
}
|
||||
@@ -22,7 +22,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
args: []string{"hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22022,
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
@@ -30,7 +30,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
args: []string{"user@hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "user",
|
||||
expectedPort: 22022,
|
||||
expectedPort: 22,
|
||||
expectedCmd: "",
|
||||
},
|
||||
{
|
||||
@@ -38,7 +38,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
args: []string{"hostname", "echo", "hello"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22022,
|
||||
expectedPort: 22,
|
||||
expectedCmd: "echo hello",
|
||||
},
|
||||
{
|
||||
@@ -46,7 +46,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
args: []string{"hostname", "ls", "-la", "/tmp"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22022,
|
||||
expectedPort: 22,
|
||||
expectedCmd: "ls -la /tmp",
|
||||
},
|
||||
{
|
||||
@@ -54,7 +54,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
args: []string{"hostname", "--", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedUser: "",
|
||||
expectedPort: 22022,
|
||||
expectedPort: 22,
|
||||
expectedCmd: "-- ls -la",
|
||||
},
|
||||
}
|
||||
@@ -64,7 +64,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22022
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
// Mock command for testing
|
||||
@@ -78,7 +78,7 @@ func TestSSHCommand_FlagParsing(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
if tt.expectedUser != "" {
|
||||
assert.Equal(t, tt.expectedUser, username, "username mismatch")
|
||||
@@ -128,12 +128,12 @@ func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22022
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
@@ -192,12 +192,12 @@ func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22022
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
@@ -258,7 +258,7 @@ func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22022
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
@@ -269,7 +269,7 @@ func TestSSHCommand_FlagHandling(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
})
|
||||
@@ -318,7 +318,7 @@ func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22022
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
cmd := sshCmd
|
||||
@@ -329,7 +329,7 @@ func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description)
|
||||
|
||||
@@ -340,3 +340,330 @@ func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local port forwarding -L",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Single -L flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "remote port forwarding -R",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectError: false,
|
||||
description: "Single -R flag should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple local port forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Multiple -L flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "multiple remote port forwards",
|
||||
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Multiple -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "mixed local and remote forwards",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectError: false,
|
||||
description: "Mixed -L and -R flags should be parsed correctly",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with bind address",
|
||||
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "port forwarding with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectError: false,
|
||||
description: "Port forwarding with command should work",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePortForward(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
spec string
|
||||
expectedLocal string
|
||||
expectedRemote string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "simple port forward",
|
||||
spec: "8080:localhost:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Simple port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with bind address",
|
||||
spec: "127.0.0.1:8080:localhost:80",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "bind_address:port:host:port format should work",
|
||||
},
|
||||
{
|
||||
name: "port forward to different host",
|
||||
spec: "8080:example.com:443",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "example.com:443",
|
||||
expectError: false,
|
||||
description: "Forwarding to different host should work",
|
||||
},
|
||||
{
|
||||
name: "port forward with IPv6 (needs bracket support)",
|
||||
spec: "::1:8080:localhost:80",
|
||||
expectError: true,
|
||||
description: "IPv6 without brackets fails as expected (feature to implement)",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too few parts",
|
||||
spec: "8080:localhost",
|
||||
expectError: true,
|
||||
description: "Invalid format with too few parts should fail",
|
||||
},
|
||||
{
|
||||
name: "invalid format - too many parts",
|
||||
spec: "127.0.0.1:8080:localhost:80:extra",
|
||||
expectError: true,
|
||||
description: "Invalid format with too many parts should fail",
|
||||
},
|
||||
{
|
||||
name: "empty spec",
|
||||
spec: "",
|
||||
expectError: true,
|
||||
description: "Empty spec should fail",
|
||||
},
|
||||
{
|
||||
name: "unix socket local forward",
|
||||
spec: "8080:/tmp/socket",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket forwarding should work",
|
||||
},
|
||||
{
|
||||
name: "unix socket with bind address",
|
||||
spec: "127.0.0.1:8080:/tmp/socket",
|
||||
expectedLocal: "127.0.0.1:8080",
|
||||
expectedRemote: "/tmp/socket",
|
||||
expectError: false,
|
||||
description: "Unix socket with bind address should work",
|
||||
},
|
||||
{
|
||||
name: "wildcard bind all interfaces",
|
||||
spec: "*:8080:localhost:80",
|
||||
expectedLocal: "0.0.0.0:8080",
|
||||
expectedRemote: "localhost:80",
|
||||
expectError: false,
|
||||
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
|
||||
},
|
||||
{
|
||||
name: "wildcard for port only",
|
||||
spec: "8080:*:80",
|
||||
expectedLocal: "localhost:8080",
|
||||
expectedRemote: "*:80",
|
||||
expectError: false,
|
||||
description: "Wildcard in remote host should be preserved",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err, tt.description)
|
||||
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
|
||||
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
|
||||
// Integration test for port forwarding with the actual SSH command implementation
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedHost string
|
||||
expectedLocal []string
|
||||
expectedRemote []string
|
||||
expectedCmd string
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "local forward with command",
|
||||
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{},
|
||||
expectedCmd: "echo test",
|
||||
description: "Local forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "remote forward with command",
|
||||
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{},
|
||||
expectedRemote: []string{"8080:localhost:80"},
|
||||
expectedCmd: "ls -la",
|
||||
description: "Remote forwarding should work with commands",
|
||||
},
|
||||
{
|
||||
name: "multiple forwards with user and command",
|
||||
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
|
||||
expectedHost: "hostname",
|
||||
expectedLocal: []string{"8080:localhost:80"},
|
||||
expectedRemote: []string{"9090:localhost:443"},
|
||||
expectedCmd: "ps aux",
|
||||
description: "Complex case with multiple forwards, user, and command",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Reset global variables
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
localForwards = nil
|
||||
remoteForwards = nil
|
||||
|
||||
cmd := sshCmd
|
||||
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
|
||||
require.NoError(t, err, "SSH args validation should succeed for valid input")
|
||||
|
||||
assert.Equal(t, tt.expectedHost, host, "host mismatch")
|
||||
// Handle nil vs empty slice comparison
|
||||
if len(tt.expectedLocal) == 0 {
|
||||
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
|
||||
}
|
||||
if len(tt.expectedRemote) == 0 {
|
||||
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
|
||||
} else {
|
||||
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
|
||||
}
|
||||
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHCommand_ParameterIsolation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedCmd string
|
||||
}{
|
||||
{
|
||||
name: "cmd flag passed as command",
|
||||
args: []string{"hostname", "--cmd", "echo test"},
|
||||
expectedCmd: "--cmd echo test",
|
||||
},
|
||||
{
|
||||
name: "uid flag passed as command",
|
||||
args: []string{"hostname", "--uid", "1000"},
|
||||
expectedCmd: "--uid 1000",
|
||||
},
|
||||
{
|
||||
name: "shell flag passed as command",
|
||||
args: []string{"hostname", "--shell", "/bin/bash"},
|
||||
expectedCmd: "--shell /bin/bash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
host = ""
|
||||
username = ""
|
||||
port = 22
|
||||
command = ""
|
||||
|
||||
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "hostname", host)
|
||||
assert.Equal(t, tt.expectedCmd, command)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -258,6 +258,22 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*interna
|
||||
ic.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
ic.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
ic.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -352,6 +368,22 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.ServerSSHAllowed = &serverSSHAllowed
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRootFlag).Changed {
|
||||
loginRequest.EnableSSHRoot = &enableSSHRoot
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHSFTPFlag).Changed {
|
||||
loginRequest.EnableSSHSFTP = &enableSSHSFTP
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user