mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Fix SSH proxy stripping shell quoting from forwarded commands (#5669)
This commit is contained in:
@@ -141,7 +141,7 @@ func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
|
||||
if err != nil {
|
||||
@@ -180,7 +180,7 @@ func (p *SSHProxy) handleSSHSession(session ssh.Session) {
|
||||
}
|
||||
|
||||
if hasCommand {
|
||||
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
|
||||
if err := serverSession.Run(session.RawCommand()); err != nil {
|
||||
log.Debugf("run command: %v", err)
|
||||
p.handleProxyExitCode(session, err)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
@@ -245,6 +246,191 @@ func TestSSHProxy_Connect(t *testing.T) {
|
||||
cancel()
|
||||
}
|
||||
|
||||
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
||||
// when forwarding commands to the backend. This is critical for tools like
|
||||
// Ansible that send commands such as:
|
||||
//
|
||||
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
||||
//
|
||||
// The single quotes must be preserved so the backend shell receives the
|
||||
// subshell expression as a single argument to -c.
|
||||
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
sshClient, cleanup := setupProxySSHClient(t)
|
||||
defer cleanup()
|
||||
|
||||
// These commands simulate what the SSH protocol delivers as exec payloads.
|
||||
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
||||
// the local shell strips the outer single quotes, and the SSH exec request
|
||||
// contains the raw string: /bin/sh -c "( echo hello )"
|
||||
//
|
||||
// The proxy must forward this string verbatim. Using session.Command()
|
||||
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
||||
// the command on the backend.
|
||||
tests := []struct {
|
||||
name string
|
||||
command string
|
||||
expect string
|
||||
}{
|
||||
{
|
||||
name: "subshell_in_double_quotes",
|
||||
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
||||
expect: "from-subshell\nouter\n",
|
||||
},
|
||||
{
|
||||
name: "printf_with_special_chars",
|
||||
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
||||
expect: "hello world\n",
|
||||
},
|
||||
{
|
||||
name: "nested_command_substitution",
|
||||
command: `/bin/sh -c "echo $(echo nested)"`,
|
||||
expect: "nested\n",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = session.Close() }()
|
||||
|
||||
var stderrBuf bytes.Buffer
|
||||
session.Stderr = &stderrBuf
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output(tc.command)
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
if stderrBuf.Len() > 0 {
|
||||
t.Logf("stderr: %s", stderrBuf.String())
|
||||
}
|
||||
require.NoError(t, err, "command should succeed: %s", tc.command)
|
||||
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("command timed out: %s", tc.command)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// setupProxySSHClient creates a full proxy test environment and returns
|
||||
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
||||
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
||||
t.Helper()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audiences: []string{audience},
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
testUsername := testutil.GetTestUsername(t)
|
||||
testJWTUser := "test-username"
|
||||
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
||||
require.NoError(t, err)
|
||||
|
||||
authConfig := &sshauth.Config{
|
||||
UserIDClaim: sshauth.DefaultUserIDClaim,
|
||||
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
||||
MachineUsers: map[string][]uint32{
|
||||
testUsername: {0},
|
||||
},
|
||||
}
|
||||
sshServer.UpdateSSHAuth(authConfig)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
|
||||
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
||||
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
|
||||
go func() {
|
||||
_ = proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
cleanupFn := func() {
|
||||
_ = client.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
_ = sshServer.Stop()
|
||||
mockDaemon.stop()
|
||||
jwksServer.Close()
|
||||
}
|
||||
|
||||
return client, cleanupFn
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
|
||||
@@ -60,7 +60,7 @@ func (s *Server) sessionHandler(session ssh.Session) {
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
hasCommand := session.RawCommand() != ""
|
||||
|
||||
if isPty && !hasCommand {
|
||||
// ssh <host> - PTY interactive session (login)
|
||||
|
||||
Reference in New Issue
Block a user