[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -35,7 +35,6 @@ const (
wireguardPortFlag = "wireguard-port"
networkMonitorFlag = "network-monitor"
disableAutoConnectFlag = "disable-auto-connect"
serverSSHAllowedFlag = "allow-server-ssh"
extraIFaceBlackListFlag = "extra-iface-blacklist"
dnsRouteIntervalFlag = "dns-router-interval"
enableLazyConnectionFlag = "enable-lazy-connection"
@@ -64,7 +63,6 @@ var (
customDNSAddress string
rosenpassEnabled bool
rosenpassPermissive bool
serverSSHAllowed bool
interfaceName string
wireguardPort uint16
networkMonitor bool
@@ -176,7 +174,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")

View File

@@ -3,125 +3,809 @@ package cmd
import (
"context"
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
var (
port int
userName = "root"
host string
const (
sshUsernameDesc = "SSH username"
hostArgumentRequired = "host argument required"
serverSSHAllowedFlag = "allow-server-ssh"
enableSSHRootFlag = "enable-ssh-root"
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
sshJWTCacheTTLFlag = "ssh-jwt-cache-ttl"
)
var sshCmd = &cobra.Command{
Use: "ssh [user@]host",
Args: func(cmd *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New("requires a host argument")
}
var (
port int
username string
host string
command string
localForwards []string
remoteForwards []string
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
requestPTY bool
)
split := strings.Split(args[0], "@")
if len(split) == 2 {
userName = split[0]
host = split[1]
} else {
host = args[0]
}
var (
serverSSHAllowed bool
enableSSHRoot bool
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
sshJWTCacheTTL int
)
return nil
},
Short: "Connect to a remote SSH server",
RunE: func(cmd *cobra.Command, args []string) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
func init() {
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer")
upCmd.PersistentFlags().BoolVar(&enableSSHRoot, enableSSHRootFlag, false, "Enable root login for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
upCmd.PersistentFlags().IntVar(&sshJWTCacheTTL, sshJWTCacheTTLFlag, 0, "SSH JWT token cache TTL in seconds (0=disabled)")
cmd.SetOut(cmd.OutOrStdout())
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
sshCmd.PersistentFlags().StringVar(&username, "login", "", sshUsernameDesc+" (alias for --user)")
sshCmd.PersistentFlags().BoolVarP(&requestPTY, "tty", "t", false, "Force pseudo-terminal allocation")
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
err := util.InitLog(logLevel, util.LogConsole)
if err != nil {
return fmt.Errorf("failed initializing log %v", err)
}
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
if !util.IsAdmin() {
cmd.Printf("error: you must have Administrator privileges to run this command\n")
return nil
}
ctx := internal.CtxInitState(cmd.Context())
sm := profilemanager.NewServiceManager(configPath)
activeProf, err := sm.GetActiveProfileState()
if err != nil {
return fmt.Errorf("get active profile: %v", err)
}
profPath, err := activeProf.FilePath()
if err != nil {
return fmt.Errorf("get active profile path: %v", err)
}
config, err := profilemanager.ReadConfig(profPath)
if err != nil {
return fmt.Errorf("read profile config: %v", err)
}
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
go func() {
// blocking
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
cmd.Printf("Error: %v\n", err)
os.Exit(1)
}
cancel()
}()
select {
case <-sig:
cancel()
case <-sshctx.Done():
}
return nil
},
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) error {
c, err := nbssh.DialWithKey(fmt.Sprintf("%s:%d", addr, port), userName, pemKey)
if err != nil {
cmd.Printf("Error: %v\n", err)
cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" +
"\nYou can verify the connection by running:\n\n" +
" netbird status\n\n")
return err
}
go func() {
<-ctx.Done()
err = c.Close()
if err != nil {
return
var sshCmd = &cobra.Command{
Use: "ssh [flags] [user@]host [command]",
Short: "Connect to a NetBird peer via SSH",
Long: `Connect to a NetBird peer using SSH with support for port forwarding.
Port Forwarding:
-L [bind_address:]port:host:hostport Local port forwarding
-L [bind_address:]port:/path/to/socket Local port forwarding to Unix socket
-R [bind_address:]port:host:hostport Remote port forwarding
-R [bind_address:]port:/path/to/socket Remote port forwarding to Unix socket
SSH Options:
-p, --port int Remote SSH port (default 22)
-u, --user string SSH username
--login string SSH username (alias for --user)
-t, --tty Force pseudo-terminal allocation
--strict-host-key-checking Enable strict host key checking (default: true)
-o, --known-hosts string Path to known_hosts file
Examples:
netbird ssh peer-hostname
netbird ssh root@peer-hostname
netbird ssh --login root peer-hostname
netbird ssh peer-hostname ls -la
netbird ssh peer-hostname whoami
netbird ssh -t peer-hostname tmux # Force PTY for tmux/screen
netbird ssh -t peer-hostname sudo -i # Force PTY for interactive sudo
netbird ssh -L 8080:localhost:80 peer-hostname # Local port forwarding
netbird ssh -R 9090:localhost:3000 peer-hostname # Remote port forwarding
netbird ssh -L "*:8080:localhost:80" peer-hostname # Bind to all interfaces
netbird ssh -L 8080:/tmp/socket peer-hostname # Unix socket forwarding`,
DisableFlagParsing: true,
Args: validateSSHArgsWithoutFlagParsing,
RunE: sshFn,
Aliases: []string{"ssh"},
}
func sshFn(cmd *cobra.Command, args []string) error {
for _, arg := range args {
if arg == "-h" || arg == "--help" {
return cmd.Help()
}
}
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
cmd.SetOut(cmd.OutOrStdout())
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
ctx := internal.CtxInitState(cmd.Context())
sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGTERM, syscall.SIGINT)
sshctx, cancel := context.WithCancel(ctx)
errCh := make(chan error, 1)
go func() {
if err := runSSH(sshctx, host, cmd); err != nil {
errCh <- err
}
cancel()
}()
err = c.OpenTerminal()
if err != nil {
select {
case <-sig:
cancel()
<-sshctx.Done()
return nil
case err := <-errCh:
return err
case <-sshctx.Done():
}
return nil
}
func init() {
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", nbssh.DefaultSSHPort, "Sets remote SSH port. Defaults to "+fmt.Sprint(nbssh.DefaultSSHPort))
// getEnvOrDefault checks for environment variables with WT_ and NB_ prefixes
func getEnvOrDefault(flagName, defaultValue string) string {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
return envValue
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
return envValue
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() {
port = sshserver.DefaultSSHPort
username = ""
host = ""
command = ""
localForwards = nil
remoteForwards = nil
strictHostKeyChecking = true
knownHostsFile = ""
identityFile = ""
}
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args
func parseCustomSSHFlags(args []string) ([]string, []string, []string) {
var localForwardFlags []string
var remoteForwardFlags []string
var filteredArgs []string
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case strings.HasPrefix(arg, "-L"):
localForwardFlags, i = parseForwardFlag(arg, args, i, localForwardFlags)
case strings.HasPrefix(arg, "-R"):
remoteForwardFlags, i = parseForwardFlag(arg, args, i, remoteForwardFlags)
default:
filteredArgs = append(filteredArgs, arg)
}
}
return filteredArgs, localForwardFlags, remoteForwardFlags
}
func parseForwardFlag(arg string, args []string, i int, flags []string) ([]string, int) {
if arg == "-L" || arg == "-R" {
if i+1 < len(args) {
flags = append(flags, args[i+1])
i++
}
} else if len(arg) > 2 {
flags = append(flags, arg[2:])
}
return flags, i
}
// extractGlobalFlags parses global flags that were passed before 'ssh' command
func extractGlobalFlags(args []string) {
sshPos := findSSHCommandPosition(args)
if sshPos == -1 {
return
}
globalArgs := args[:sshPos]
parseGlobalArgs(globalArgs)
}
// findSSHCommandPosition locates the 'ssh' command in the argument list
func findSSHCommandPosition(args []string) int {
for i, arg := range args {
if arg == "ssh" {
return i
}
}
return -1
}
const (
configFlag = "config"
logLevelFlag = "log-level"
logFileFlag = "log-file"
)
// parseGlobalArgs processes the global arguments and sets the corresponding variables
func parseGlobalArgs(globalArgs []string) {
flagHandlers := map[string]func(string){
configFlag: func(value string) { configPath = value },
logLevelFlag: func(value string) { logLevel = value },
logFileFlag: func(value string) {
if !slices.Contains(logFiles, value) {
logFiles = append(logFiles, value)
}
},
}
shortFlags := map[string]string{
"c": configFlag,
"l": logLevelFlag,
}
for i := 0; i < len(globalArgs); i++ {
arg := globalArgs[i]
if handled, nextIndex := parseFlag(arg, globalArgs, i, flagHandlers, shortFlags); handled {
i = nextIndex
}
}
}
// parseFlag handles generic flag parsing for both long and short forms
func parseFlag(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (bool, int) {
if parsedValue, found := parseEqualsFormat(arg, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex
}
if parsedValue, found := parseSpacedFormat(arg, args, currentIndex, flagHandlers, shortFlags); found {
flagHandlers[parsedValue.flagName](parsedValue.value)
return true, currentIndex + 1
}
return false, currentIndex
}
type parsedFlag struct {
flagName string
value string
}
// parseEqualsFormat handles --flag=value and -f=value formats
func parseEqualsFormat(arg string, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if !strings.Contains(arg, "=") {
return parsedFlag{}, false
}
parts := strings.SplitN(arg, "=", 2)
if len(parts) != 2 {
return parsedFlag{}, false
}
if strings.HasPrefix(parts[0], "--") {
flagName := strings.TrimPrefix(parts[0], "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: parts[1]}, true
}
}
if strings.HasPrefix(parts[0], "-") && len(parts[0]) == 2 {
shortFlag := strings.TrimPrefix(parts[0], "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: parts[1]}, true
}
}
}
return parsedFlag{}, false
}
// parseSpacedFormat handles --flag value and -f value formats
func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers map[string]func(string), shortFlags map[string]string) (parsedFlag, bool) {
if currentIndex+1 >= len(args) {
return parsedFlag{}, false
}
if strings.HasPrefix(arg, "--") {
flagName := strings.TrimPrefix(arg, "--")
if _, exists := flagHandlers[flagName]; exists {
return parsedFlag{flagName: flagName, value: args[currentIndex+1]}, true
}
}
if strings.HasPrefix(arg, "-") && len(arg) == 2 {
shortFlag := strings.TrimPrefix(arg, "-")
if longFlag, exists := shortFlags[shortFlag]; exists {
if _, exists := flagHandlers[longFlag]; exists {
return parsedFlag{flagName: longFlag, value: args[currentIndex+1]}, true
}
}
}
return parsedFlag{}, false
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
RequestPTY bool
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.IntVar(&flags.Port, "port", sshserver.DefaultSSHPort, "SSH port")
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "user", "", sshUsernameDesc)
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
fs.BoolVar(&flags.RequestPTY, "t", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.RequestPTY, "tty", false, "Force pseudo-terminal allocation")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.StringVar(&flags.KnownHostsFile, "known-hosts", "", "Path to known_hosts file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "log-level", defaultLogLevel, "sets Netbird log level")
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
resetSSHGlobals()
if len(os.Args) > 2 {
extractGlobalFlags(os.Args[1:])
}
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
if errors.Is(err, flag.ErrHelp) {
return nil
}
return err
}
remaining := fs.Args()
if len(remaining) < 1 {
return errors.New(hostArgumentRequired)
}
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
requestPTY = flags.RequestPTY
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
remoteForwards = remoteForwardFlags
return parseHostnameAndCommand(remaining)
}
func parseHostnameAndCommand(args []string) error {
if len(args) < 1 {
return errors.New(hostArgumentRequired)
}
arg := args[0]
if strings.Contains(arg, "@") {
parts := strings.SplitN(arg, "@", 2)
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.New("invalid user@host format")
}
if username == "" {
username = parts[0]
}
host = parts[1]
} else {
host = arg
}
if username == "" {
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
username = sudoUser
} else if currentUser, err := user.Current(); err == nil {
username = currentUser.Username
} else {
username = "root"
}
}
// Everything after hostname becomes the command
if len(args) > 1 {
command = strings.Join(args[1:], " ")
}
return nil
}
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
return fmt.Errorf("dial %s: %w", target, err)
}
sshCtx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-sshCtx.Done()
if err := c.Close(); err != nil {
cmd.Printf("Error closing SSH connection: %v\n", err)
}
}()
if err := startPortForwarding(sshCtx, c, cmd); err != nil {
return fmt.Errorf("start port forwarding: %w", err)
}
if command != "" {
return executeSSHCommand(sshCtx, c, command)
}
return openSSHTerminal(sshCtx, c)
}
// executeSSHCommand executes a command over SSH.
func executeSSHCommand(ctx context.Context, c *sshclient.Client, command string) error {
var err error
if requestPTY {
err = c.ExecuteCommandWithPTY(ctx, command)
} else {
err = c.ExecuteCommandWithIO(ctx, command)
}
if err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitErr *ssh.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.ExitStatus())
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote command exited without exit status: %v", err)
return nil
}
return fmt.Errorf("execute command: %w", err)
}
return nil
}
// openSSHTerminal opens an interactive SSH terminal.
func openSSHTerminal(ctx context.Context, c *sshclient.Client) error {
if err := c.OpenTerminal(ctx); err != nil {
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return nil
}
var exitMissingErr *ssh.ExitMissingError
if errors.As(err, &exitMissingErr) {
log.Debugf("Remote terminal exited without exit status: %v", err)
return nil
}
return fmt.Errorf("open terminal: %w", err)
}
return nil
}
// startPortForwarding starts local and remote port forwarding based on command line flags
func startPortForwarding(ctx context.Context, c *sshclient.Client, cmd *cobra.Command) error {
for _, forward := range localForwards {
if err := parseAndStartLocalForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("local port forward %s: %w", forward, err)
}
}
for _, forward := range remoteForwards {
if err := parseAndStartRemoteForward(ctx, c, forward, cmd); err != nil {
return fmt.Errorf("remote port forward %s: %w", forward, err)
}
}
return nil
}
// parseAndStartLocalForward parses and starts a local port forward (-L)
func parseAndStartLocalForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
localAddr, remoteAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Local port forwarding: %s -> %s\n", localAddr, remoteAddr)
go func() {
if err := c.LocalPortForward(ctx, localAddr, remoteAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Local port forward error: %v\n", err)
}
}()
return nil
}
// parseAndStartRemoteForward parses and starts a remote port forward (-R)
func parseAndStartRemoteForward(ctx context.Context, c *sshclient.Client, forward string, cmd *cobra.Command) error {
remoteAddr, localAddr, err := parsePortForwardSpec(forward)
if err != nil {
return err
}
cmd.Printf("Remote port forwarding: %s -> %s\n", remoteAddr, localAddr)
go func() {
if err := c.RemotePortForward(ctx, remoteAddr, localAddr); err != nil && !errors.Is(err, context.Canceled) {
cmd.Printf("Remote port forward error: %v\n", err)
}
}()
return nil
}
// parsePortForwardSpec parses port forward specifications like "8080:localhost:80" or "[::1]:8080:localhost:80".
// Also supports Unix sockets like "8080:/tmp/socket" or "127.0.0.1:8080:/tmp/socket".
func parsePortForwardSpec(spec string) (string, string, error) {
// Support formats:
// port:host:hostport -> localhost:port -> host:hostport
// host:port:host:hostport -> host:port -> host:hostport
// [host]:port:host:hostport -> [host]:port -> host:hostport
// port:unix_socket_path -> localhost:port -> unix_socket_path
// host:port:unix_socket_path -> host:port -> unix_socket_path
if strings.HasPrefix(spec, "[") && strings.Contains(spec, "]:") {
return parseIPv6ForwardSpec(spec)
}
parts := strings.Split(spec, ":")
if len(parts) < 2 {
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_target)", spec)
}
switch len(parts) {
case 2:
return parseTwoPartForwardSpec(parts, spec)
case 3:
return parseThreePartForwardSpec(parts)
case 4:
return parseFourPartForwardSpec(parts)
default:
return "", "", fmt.Errorf("invalid port forward specification: %s", spec)
}
}
// parseTwoPartForwardSpec handles "port:unix_socket" format.
func parseTwoPartForwardSpec(parts []string, spec string) (string, string, error) {
if isUnixSocket(parts[1]) {
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1]
return localAddr, remoteAddr, nil
}
return "", "", fmt.Errorf("invalid port forward specification: %s (expected format: [local_host:]local_port:remote_host:remote_port or [local_host:]local_port:unix_socket)", spec)
}
// parseThreePartForwardSpec handles "port:host:hostport" or "host:port:unix_socket" formats.
func parseThreePartForwardSpec(parts []string) (string, string, error) {
if isUnixSocket(parts[2]) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2]
return localAddr, remoteAddr, nil
}
localAddr := "localhost:" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// parseFourPartForwardSpec handles "host:port:host:hostport" format.
func parseFourPartForwardSpec(parts []string) (string, string, error) {
localHost := normalizeLocalHost(parts[0])
localAddr := localHost + ":" + parts[1]
remoteAddr := parts[2] + ":" + parts[3]
return localAddr, remoteAddr, nil
}
// parseIPv6ForwardSpec handles "[host]:port:host:hostport" format.
func parseIPv6ForwardSpec(spec string) (string, string, error) {
idx := strings.Index(spec, "]:")
if idx == -1 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s", spec)
}
ipv6Host := spec[:idx+1]
remaining := spec[idx+2:]
parts := strings.Split(remaining, ":")
if len(parts) != 3 {
return "", "", fmt.Errorf("invalid IPv6 port forward specification: %s (expected [ipv6]:port:host:hostport)", spec)
}
localAddr := ipv6Host + ":" + parts[0]
remoteAddr := parts[1] + ":" + parts[2]
return localAddr, remoteAddr, nil
}
// isUnixSocket checks if a path is a Unix socket path.
func isUnixSocket(path string) bool {
return strings.HasPrefix(path, "/") || strings.HasPrefix(path, "./")
}
// normalizeLocalHost converts "*" to "0.0.0.0" for binding to all interfaces.
func normalizeLocalHost(host string) string {
if host == "*" {
return "0.0.0.0"
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
logOutput := "console"
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
defer func() {
if err := proxy.Close(); err != nil {
log.Debugf("close SSH proxy: %v", err)
}
}()
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -0,0 +1,74 @@
//go:build unix
package cmd
import (
"fmt"
"os"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sshExecUID uint32
sshExecGID uint32
sshExecGroups []uint
sshExecWorkingDir string
sshExecShell string
sshExecCommand string
sshExecPTY bool
)
// sshExecCmd represents the hidden ssh exec subcommand for privilege dropping
var sshExecCmd = &cobra.Command{
Use: "exec",
Short: "Internal SSH execution with privilege dropping (hidden)",
Hidden: true,
RunE: runSSHExec,
}
func init() {
sshExecCmd.Flags().Uint32Var(&sshExecUID, "uid", 0, "Target user ID")
sshExecCmd.Flags().Uint32Var(&sshExecGID, "gid", 0, "Target group ID")
sshExecCmd.Flags().UintSliceVar(&sshExecGroups, "groups", nil, "Supplementary group IDs (can be repeated)")
sshExecCmd.Flags().StringVar(&sshExecWorkingDir, "working-dir", "", "Working directory")
sshExecCmd.Flags().StringVar(&sshExecShell, "shell", "/bin/sh", "Shell to execute")
sshExecCmd.Flags().BoolVar(&sshExecPTY, "pty", false, "Request PTY (will fail as executor doesn't support PTY)")
sshExecCmd.Flags().StringVar(&sshExecCommand, "cmd", "", "Command to execute")
if err := sshExecCmd.MarkFlagRequired("uid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark uid flag as required: %v\n", err)
os.Exit(1)
}
if err := sshExecCmd.MarkFlagRequired("gid"); err != nil {
_, _ = fmt.Fprintf(os.Stderr, "failed to mark gid flag as required: %v\n", err)
os.Exit(1)
}
sshCmd.AddCommand(sshExecCmd)
}
// runSSHExec handles the SSH exec subcommand execution.
func runSSHExec(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sshExecGroups {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sshExecUID,
GID: sshExecGID,
Groups: groups,
WorkingDir: sshExecWorkingDir,
Shell: sshExecShell,
Command: sshExecCommand,
PTY: sshExecPTY,
}
privilegeDropper.ExecuteWithPrivilegeDrop(cmd.Context(), config)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build unix
package cmd
import (
"errors"
"io"
"os"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpUID uint32
sftpGID uint32
sftpGroupsInt []uint
sftpWorkingDir string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with privilege dropping (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().Uint32Var(&sftpUID, "uid", 0, "Target user ID")
sshSftpCmd.Flags().Uint32Var(&sftpGID, "gid", 0, "Target group ID")
sshSftpCmd.Flags().UintSliceVar(&sftpGroupsInt, "groups", nil, "Supplementary group IDs (can be repeated)")
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
privilegeDropper := sshserver.NewPrivilegeDropper()
var groups []uint32
for _, groupInt := range sftpGroupsInt {
groups = append(groups, uint32(groupInt))
}
config := sshserver.ExecutorConfig{
UID: sftpUID,
GID: sftpGID,
Groups: groups,
WorkingDir: sftpWorkingDir,
Shell: "",
Command: "",
}
log.Tracef("dropping privileges for SFTP to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
if err := privilegeDropper.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
cmd.PrintErrf("privilege drop failed: %v\n", err)
os.Exit(sshserver.ExitCodePrivilegeDropFail)
}
if config.WorkingDir != "" {
if err := os.Chdir(config.WorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", config.WorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Tracef("starting SFTP server with dropped privileges")
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeShellExecFail)
}
if closeErr := sftpServer.Close(); closeErr != nil {
cmd.PrintErrf("SFTP server close error: %v\n", closeErr)
}
os.Exit(sshserver.ExitCodeSuccess)
return nil
}

View File

@@ -0,0 +1,94 @@
//go:build windows
package cmd
import (
"errors"
"fmt"
"io"
"os"
"os/user"
"strings"
"github.com/pkg/sftp"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
var (
sftpWorkingDir string
windowsUsername string
windowsDomain string
)
var sshSftpCmd = &cobra.Command{
Use: "sftp",
Short: "SFTP server with user switching for Windows (internal use)",
Hidden: true,
RunE: sftpMain,
}
func init() {
sshSftpCmd.Flags().StringVar(&sftpWorkingDir, "working-dir", "", "Working directory")
sshSftpCmd.Flags().StringVar(&windowsUsername, "windows-username", "", "Windows username for user switching")
sshSftpCmd.Flags().StringVar(&windowsDomain, "windows-domain", "", "Windows domain for user switching")
}
func sftpMain(cmd *cobra.Command, _ []string) error {
return sftpMainDirect(cmd)
}
func sftpMainDirect(cmd *cobra.Command) error {
currentUser, err := user.Current()
if err != nil {
cmd.PrintErrf("failed to get current user: %v\n", err)
os.Exit(sshserver.ExitCodeValidationFail)
}
if windowsUsername != "" {
expectedUsername := windowsUsername
if windowsDomain != "" {
expectedUsername = fmt.Sprintf(`%s\%s`, windowsDomain, windowsUsername)
}
if !strings.EqualFold(currentUser.Username, expectedUsername) && !strings.EqualFold(currentUser.Username, windowsUsername) {
cmd.PrintErrf("user switching failed\n")
os.Exit(sshserver.ExitCodeValidationFail)
}
}
log.Debugf("SFTP process running as: %s (UID: %s, Name: %s)", currentUser.Username, currentUser.Uid, currentUser.Name)
if sftpWorkingDir != "" {
if err := os.Chdir(sftpWorkingDir); err != nil {
cmd.PrintErrf("failed to change to working directory %s: %v\n", sftpWorkingDir, err)
}
}
sftpServer, err := sftp.NewServer(struct {
io.Reader
io.WriteCloser
}{
Reader: os.Stdin,
WriteCloser: os.Stdout,
})
if err != nil {
cmd.PrintErrf("SFTP server creation failed: %v\n", err)
os.Exit(sshserver.ExitCodeShellExecFail)
}
log.Debugf("starting SFTP server")
exitCode := sshserver.ExitCodeSuccess
if err := sftpServer.Serve(); err != nil && !errors.Is(err, io.EOF) {
cmd.PrintErrf("SFTP server error: %v\n", err)
exitCode = sshserver.ExitCodeShellExecFail
}
if err := sftpServer.Close(); err != nil {
log.Debugf("SFTP server close error: %v", err)
}
os.Exit(exitCode)
return nil
}

717
client/cmd/ssh_test.go Normal file
View File

@@ -0,0 +1,717 @@
package cmd
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestSSHCommand_FlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedUser string
expectedPort int
expectedCmd string
expectError bool
}{
{
name: "basic host",
args: []string{"hostname"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "",
},
{
name: "user@host format",
args: []string{"user@hostname"},
expectedHost: "hostname",
expectedUser: "user",
expectedPort: 22,
expectedCmd: "",
},
{
name: "host with command",
args: []string{"hostname", "echo", "hello"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "echo hello",
},
{
name: "command with flags should be preserved",
args: []string{"hostname", "ls", "-la", "/tmp"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "ls -la /tmp",
},
{
name: "double dash separator",
args: []string{"hostname", "--", "ls", "-la"},
expectedHost: "hostname",
expectedUser: "",
expectedPort: 22,
expectedCmd: "-- ls -la",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
// Mock command for testing
cmd := sshCmd
cmd.SetArgs(tt.args)
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
if tt.expectedUser != "" {
assert.Equal(t, tt.expectedUser, username, "username mismatch")
}
assert.Equal(t, tt.expectedPort, port, "port mismatch")
assert.Equal(t, tt.expectedCmd, command, "command mismatch")
})
}
}
func TestSSHCommand_FlagConflictPrevention(t *testing.T) {
// Test that SSH flags don't conflict with command flags
tests := []struct {
name string
args []string
expectedCmd string
description string
}{
{
name: "ls with -la flags",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
description: "ls flags should be passed to remote command",
},
{
name: "grep with -r flag",
args: []string{"hostname", "grep", "-r", "pattern", "/path"},
expectedCmd: "grep -r pattern /path",
description: "grep flags should be passed to remote command",
},
{
name: "ps with aux flags",
args: []string{"hostname", "ps", "aux"},
expectedCmd: "ps aux",
description: "ps flags should be passed to remote command",
},
{
name: "command with double dash",
args: []string{"hostname", "--", "ls", "-la"},
expectedCmd: "-- ls -la",
description: "double dash should be preserved in command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_NonInteractiveExecution(t *testing.T) {
// Test that commands with arguments should execute the command and exit,
// not drop to an interactive shell
tests := []struct {
name string
args []string
expectedCmd string
shouldExit bool
description string
}{
{
name: "ls command should execute and exit",
args: []string{"hostname", "ls"},
expectedCmd: "ls",
shouldExit: true,
description: "ls command should execute and exit, not drop to shell",
},
{
name: "ls with flags should execute and exit",
args: []string{"hostname", "ls", "-la"},
expectedCmd: "ls -la",
shouldExit: true,
description: "ls with flags should execute and exit, not drop to shell",
},
{
name: "pwd command should execute and exit",
args: []string{"hostname", "pwd"},
expectedCmd: "pwd",
shouldExit: true,
description: "pwd command should execute and exit, not drop to shell",
},
{
name: "echo command should execute and exit",
args: []string{"hostname", "echo", "hello"},
expectedCmd: "echo hello",
shouldExit: true,
description: "echo command should execute and exit, not drop to shell",
},
{
name: "no command should open shell",
args: []string{"hostname"},
expectedCmd: "",
shouldExit: false,
description: "no command should open interactive shell",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// When command is present, it should execute the command and exit
// When command is empty, it should open interactive shell
hasCommand := command != ""
assert.Equal(t, tt.shouldExit, hasCommand, "Command presence should match expected behavior")
})
}
}
func TestSSHCommand_FlagHandling(t *testing.T) {
// Test that flags after hostname are not parsed by netbird but passed to SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "ls with -la flag should not be parsed by netbird",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "ls -la should be passed as SSH command, not parsed as netbird flags",
},
{
name: "command with netbird-like flags should be passed through",
args: []string{"hostname", "echo", "--help"},
expectedHost: "hostname",
expectedCmd: "echo --help",
expectError: false,
description: "--help should be passed to echo, not parsed by netbird",
},
{
name: "command with -p flag should not conflict with SSH port flag",
args: []string{"hostname", "ps", "-p", "1234"},
expectedHost: "hostname",
expectedCmd: "ps -p 1234",
expectError: false,
description: "ps -p should be passed to ps command, not parsed as port",
},
{
name: "tar with flags should be passed through",
args: []string{"hostname", "tar", "-czf", "backup.tar.gz", "/home"},
expectedHost: "hostname",
expectedCmd: "tar -czf backup.tar.gz /home",
expectError: false,
description: "tar flags should be passed to tar command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
})
}
}
func TestSSHCommand_RegressionFlagParsing(t *testing.T) {
// Regression test for the specific issue: "sudo ./netbird ssh debian2 ls -la"
// should not parse -la as netbird flags but pass them to the SSH command
tests := []struct {
name string
args []string
expectedHost string
expectedCmd string
expectError bool
description string
}{
{
name: "original issue: ls -la should be preserved",
args: []string{"debian2", "ls", "-la"},
expectedHost: "debian2",
expectedCmd: "ls -la",
expectError: false,
description: "The original failing case should now work",
},
{
name: "ls -l should be preserved",
args: []string{"hostname", "ls", "-l"},
expectedHost: "hostname",
expectedCmd: "ls -l",
expectError: false,
description: "Single letter flags should be preserved",
},
{
name: "SSH port flag should work",
args: []string{"-p", "2222", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedCmd: "ls -la",
expectError: false,
description: "SSH -p flag should be parsed, command flags preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
assert.Equal(t, tt.expectedCmd, command, tt.description)
// Check port for the test case with -p flag
if len(tt.args) > 0 && tt.args[0] == "-p" {
assert.Equal(t, 2222, port, "port should be parsed from -p flag")
}
})
}
}
func TestSSHCommand_PortForwardingFlagParsing(t *testing.T) {
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectError bool
description string
}{
{
name: "local port forwarding -L",
args: []string{"-L", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Single -L flag should be parsed correctly",
},
{
name: "remote port forwarding -R",
args: []string{"-R", "8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectError: false,
description: "Single -R flag should be parsed correctly",
},
{
name: "multiple local port forwards",
args: []string{"-L", "8080:localhost:80", "-L", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80", "9090:localhost:443"},
expectedRemote: []string{},
expectError: false,
description: "Multiple -L flags should be parsed correctly",
},
{
name: "multiple remote port forwards",
args: []string{"-R", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80", "9090:localhost:443"},
expectError: false,
description: "Multiple -R flags should be parsed correctly",
},
{
name: "mixed local and remote forwards",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectError: false,
description: "Mixed -L and -R flags should be parsed correctly",
},
{
name: "port forwarding with bind address",
args: []string{"-L", "127.0.0.1:8080:localhost:80", "hostname"},
expectedHost: "hostname",
expectedLocal: []string{"127.0.0.1:8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with bind address should work",
},
{
name: "port forwarding with command",
args: []string{"-L", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectError: false,
description: "Port forwarding with command should work",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
if tt.expectError {
assert.Error(t, err)
return
}
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
})
}
}
func TestParsePortForward(t *testing.T) {
tests := []struct {
name string
spec string
expectedLocal string
expectedRemote string
expectError bool
description string
}{
{
name: "simple port forward",
spec: "8080:localhost:80",
expectedLocal: "localhost:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Simple port:host:port format should work",
},
{
name: "port forward with bind address",
spec: "127.0.0.1:8080:localhost:80",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "bind_address:port:host:port format should work",
},
{
name: "port forward to different host",
spec: "8080:example.com:443",
expectedLocal: "localhost:8080",
expectedRemote: "example.com:443",
expectError: false,
description: "Forwarding to different host should work",
},
{
name: "port forward with IPv6 (needs bracket support)",
spec: "::1:8080:localhost:80",
expectError: true,
description: "IPv6 without brackets fails as expected (feature to implement)",
},
{
name: "invalid format - too few parts",
spec: "8080:localhost",
expectError: true,
description: "Invalid format with too few parts should fail",
},
{
name: "invalid format - too many parts",
spec: "127.0.0.1:8080:localhost:80:extra",
expectError: true,
description: "Invalid format with too many parts should fail",
},
{
name: "empty spec",
spec: "",
expectError: true,
description: "Empty spec should fail",
},
{
name: "unix socket local forward",
spec: "8080:/tmp/socket",
expectedLocal: "localhost:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket forwarding should work",
},
{
name: "unix socket with bind address",
spec: "127.0.0.1:8080:/tmp/socket",
expectedLocal: "127.0.0.1:8080",
expectedRemote: "/tmp/socket",
expectError: false,
description: "Unix socket with bind address should work",
},
{
name: "wildcard bind all interfaces",
spec: "*:8080:localhost:80",
expectedLocal: "0.0.0.0:8080",
expectedRemote: "localhost:80",
expectError: false,
description: "Wildcard * should bind to all interfaces (0.0.0.0)",
},
{
name: "wildcard for port only",
spec: "8080:*:80",
expectedLocal: "localhost:8080",
expectedRemote: "*:80",
expectError: false,
description: "Wildcard in remote host should be preserved",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
localAddr, remoteAddr, err := parsePortForwardSpec(tt.spec)
if tt.expectError {
assert.Error(t, err, tt.description)
return
}
require.NoError(t, err, tt.description)
assert.Equal(t, tt.expectedLocal, localAddr, tt.description+" - local address")
assert.Equal(t, tt.expectedRemote, remoteAddr, tt.description+" - remote address")
})
}
}
func TestSSHCommand_IntegrationPortForwarding(t *testing.T) {
// Integration test for port forwarding with the actual SSH command implementation
tests := []struct {
name string
args []string
expectedHost string
expectedLocal []string
expectedRemote []string
expectedCmd string
description string
}{
{
name: "local forward with command",
args: []string{"-L", "8080:localhost:80", "hostname", "echo", "test"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{},
expectedCmd: "echo test",
description: "Local forwarding should work with commands",
},
{
name: "remote forward with command",
args: []string{"-R", "8080:localhost:80", "hostname", "ls", "-la"},
expectedHost: "hostname",
expectedLocal: []string{},
expectedRemote: []string{"8080:localhost:80"},
expectedCmd: "ls -la",
description: "Remote forwarding should work with commands",
},
{
name: "multiple forwards with user and command",
args: []string{"-L", "8080:localhost:80", "-R", "9090:localhost:443", "user@hostname", "ps", "aux"},
expectedHost: "hostname",
expectedLocal: []string{"8080:localhost:80"},
expectedRemote: []string{"9090:localhost:443"},
expectedCmd: "ps aux",
description: "Complex case with multiple forwards, user, and command",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
localForwards = nil
remoteForwards = nil
cmd := sshCmd
err := validateSSHArgsWithoutFlagParsing(cmd, tt.args)
require.NoError(t, err, "SSH args validation should succeed for valid input")
assert.Equal(t, tt.expectedHost, host, "host mismatch")
// Handle nil vs empty slice comparison
if len(tt.expectedLocal) == 0 {
assert.True(t, len(localForwards) == 0, tt.description+" - local forwards should be empty")
} else {
assert.Equal(t, tt.expectedLocal, localForwards, tt.description+" - local forwards")
}
if len(tt.expectedRemote) == 0 {
assert.True(t, len(remoteForwards) == 0, tt.description+" - remote forwards should be empty")
} else {
assert.Equal(t, tt.expectedRemote, remoteForwards, tt.description+" - remote forwards")
}
assert.Equal(t, tt.expectedCmd, command, tt.description+" - command")
})
}
}
func TestSSHCommand_ParameterIsolation(t *testing.T) {
tests := []struct {
name string
args []string
expectedCmd string
}{
{
name: "cmd flag passed as command",
args: []string{"hostname", "--cmd", "echo test"},
expectedCmd: "--cmd echo test",
},
{
name: "uid flag passed as command",
args: []string{"hostname", "--uid", "1000"},
expectedCmd: "--uid 1000",
},
{
name: "shell flag passed as command",
args: []string{"hostname", "--shell", "/bin/bash"},
expectedCmd: "--shell /bin/bash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
require.NoError(t, err)
assert.Equal(t, "hostname", host)
assert.Equal(t, tt.expectedCmd, command)
})
}
}
func TestSSHCommand_InvalidFlagRejection(t *testing.T) {
// Test that invalid flags are properly rejected and not misinterpreted as hostnames
tests := []struct {
name string
args []string
description string
}{
{
name: "invalid long flag before hostname",
args: []string{"--invalid-flag", "hostname"},
description: "Invalid flag should return parse error, not treat flag as hostname",
},
{
name: "invalid short flag before hostname",
args: []string{"-x", "hostname"},
description: "Invalid short flag should return parse error",
},
{
name: "invalid flag with value before hostname",
args: []string{"--invalid-option=value", "hostname"},
description: "Invalid flag with value should return parse error",
},
{
name: "typo in known flag",
args: []string{"--por", "2222", "hostname"},
description: "Typo in flag name should return parse error (not silently ignored)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Reset global variables
host = ""
username = ""
port = 22
command = ""
err := validateSSHArgsWithoutFlagParsing(sshCmd, tt.args)
// Should return an error for invalid flags
assert.Error(t, err, tt.description)
// Should not have set host to the invalid flag
assert.NotEqual(t, tt.args[0], host, "Invalid flag should not be interpreted as hostname")
})
}
}

View File

@@ -109,7 +109,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
case yamlFlag:
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
default:
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false, false)
}
if err != nil {

View File

@@ -12,6 +12,7 @@ import (
"google.golang.org/grpc"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -117,7 +118,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock())
accountManager, err := mgmt.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}

View File

@@ -355,6 +355,25 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(serverSSHAllowedFlag).Changed {
req.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
req.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
req.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
req.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
req.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -439,6 +458,30 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
ic.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
ic.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
ic.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
ic.SSHJWTCacheTTL = &sshJWTCacheTTL
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -539,6 +582,31 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.ServerSSHAllowed = &serverSSHAllowed
}
if cmd.Flag(enableSSHRootFlag).Changed {
loginRequest.EnableSSHRoot = &enableSSHRoot
}
if cmd.Flag(enableSSHSFTPFlag).Changed {
loginRequest.EnableSSHSFTP = &enableSSHSFTP
}
if cmd.Flag(enableSSHLocalPortForwardFlag).Changed {
loginRequest.EnableSSHLocalPortForwarding = &enableSSHLocalPortForward
}
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(sshJWTCacheTTLFlag).Changed {
sshJWTCacheTTL32 := int32(sshJWTCacheTTL)
loginRequest.SshJWTCacheTTL = &sshJWTCacheTTL32
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}