mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
[client] Support non-PTY no-command interactive SSH sessions (#5093)
This commit is contained in:
@@ -4,12 +4,15 @@ import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -23,25 +26,67 @@ import (
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
// On platforms where su doesn't support --pty (macOS, FreeBSD, Windows), the SSH server
|
||||
// spawns an executor subprocess via os.Executable(). During tests, this invokes the test
|
||||
// binary with "ssh exec" args. We handle that here to properly execute commands and
|
||||
// propagate exit codes.
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
runTestExecutor()
|
||||
return
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// runTestExecutor emulates the netbird executor for tests.
|
||||
// Parses --shell and --cmd args, runs the command, and exits with the correct code.
|
||||
func runTestExecutor() {
|
||||
if os.Getenv("_NETBIRD_TEST_EXECUTOR") != "" {
|
||||
fmt.Fprintf(os.Stderr, "executor recursion detected\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Setenv("_NETBIRD_TEST_EXECUTOR", "1")
|
||||
|
||||
shell := "/bin/sh"
|
||||
var command string
|
||||
for i := 3; i < len(os.Args); i++ {
|
||||
switch os.Args[i] {
|
||||
case "--shell":
|
||||
if i+1 < len(os.Args) {
|
||||
shell = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
case "--cmd":
|
||||
if i+1 < len(os.Args) {
|
||||
command = os.Args[i+1]
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var cmd *exec.Cmd
|
||||
if command == "" {
|
||||
cmd = exec.Command(shell)
|
||||
} else {
|
||||
cmd = exec.Command(shell, "-c", command)
|
||||
}
|
||||
cmd.Args[0] = "-" + filepath.Base(shell)
|
||||
cmd.Stdin = os.Stdin
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
|
||||
if err := cmd.Run(); err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
os.Exit(exitErr.ExitCode())
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
@@ -405,6 +450,171 @@ func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHPtyModes tests different PTY allocation modes (-T, -t, -tt flags)
|
||||
// This ensures our implementation matches OpenSSH behavior for:
|
||||
// - ssh host command (no PTY - default when no TTY)
|
||||
// - ssh -T host command (explicit no PTY)
|
||||
// - ssh -t host command (force PTY)
|
||||
// - ssh -T host (no PTY shell - our implementation)
|
||||
func TestSSHPtyModes(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH PTY mode tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH PTY tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
baseArgs := []string{
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "BatchMode=yes",
|
||||
}
|
||||
|
||||
t.Run("command_default_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), fmt.Sprintf("%s@%s", username, host), "echo", "no_pty_default")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (default no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "no_pty_default")
|
||||
})
|
||||
|
||||
t.Run("command_explicit_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "echo", "explicit_no_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-T explicit no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "explicit_no_pty")
|
||||
})
|
||||
|
||||
t.Run("command_force_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "echo", "force_pty")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "Command (-tt force PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "force_pty")
|
||||
})
|
||||
|
||||
t.Run("shell_explicit_no_pty", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host))
|
||||
cmd := exec.CommandContext(ctx, "ssh", args...)
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, cmd.Start(), "Shell (-T no PTY) start failed")
|
||||
|
||||
go func() {
|
||||
defer stdin.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err := stdin.Write([]byte("echo shell_no_pty_test\n"))
|
||||
assert.NoError(t, err, "write echo command")
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
_, err = stdin.Write([]byte("exit 0\n"))
|
||||
assert.NoError(t, err, "write exit command")
|
||||
}()
|
||||
|
||||
output, _ := io.ReadAll(stdout)
|
||||
err = cmd.Wait()
|
||||
|
||||
require.NoError(t, err, "Shell (-T no PTY) failed: %s", output)
|
||||
assert.Contains(t, string(output), "shell_no_pty_test")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host), "exit", "42")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "Command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 42, exitErr.ExitCode(), "Exit code should be preserved with -T")
|
||||
})
|
||||
|
||||
t.Run("exit_code_preserved_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host), "sh -c 'exit 43'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
err := cmd.Run()
|
||||
require.Error(t, err, "PTY command should exit with non-zero")
|
||||
|
||||
var exitErr *exec.ExitError
|
||||
require.True(t, errors.As(err, &exitErr), "Should be an exit error: %v", err)
|
||||
assert.Equal(t, 43, exitErr.ExitCode(), "Exit code should be preserved with -tt")
|
||||
})
|
||||
|
||||
t.Run("stderr_works_no_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-T", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
var stdout, stderr strings.Builder
|
||||
cmd.Stdout = &stdout
|
||||
cmd.Stderr = &stderr
|
||||
|
||||
require.NoError(t, cmd.Run(), "stderr test failed")
|
||||
assert.Contains(t, stdout.String(), "stdout_msg", "stdout should have stdout_msg")
|
||||
assert.Contains(t, stderr.String(), "stderr_msg", "stderr should have stderr_msg")
|
||||
assert.NotContains(t, stdout.String(), "stderr_msg", "stdout should NOT have stderr_msg")
|
||||
})
|
||||
|
||||
t.Run("stderr_merged_with_pty", func(t *testing.T) {
|
||||
args := append(slices.Clone(baseArgs), "-tt", fmt.Sprintf("%s@%s", username, host),
|
||||
"sh -c 'echo stdout_msg; echo stderr_msg >&2'")
|
||||
cmd := exec.Command("ssh", args...)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
require.NoError(t, err, "PTY stderr test failed: %s", output)
|
||||
assert.Contains(t, string(output), "stdout_msg")
|
||||
assert.Contains(t, string(output), "stderr_msg")
|
||||
})
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
|
||||
Reference in New Issue
Block a user