mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-15 23:06:38 +00:00
571 lines
15 KiB
Go
571 lines
15 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"math/big"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"runtime"
|
|
"strconv"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
cryptossh "golang.org/x/crypto/ssh"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
|
|
"github.com/netbirdio/netbird/client/proto"
|
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
|
"github.com/netbirdio/netbird/client/ssh/server"
|
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
|
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
|
sshuserhash "github.com/netbirdio/netbird/shared/sshauth"
|
|
)
|
|
|
|
func TestMain(m *testing.M) {
|
|
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
|
if os.Args[2] == "exec" {
|
|
if len(os.Args) > 3 {
|
|
cmd := os.Args[3]
|
|
if cmd == "echo" && len(os.Args) > 4 {
|
|
fmt.Fprintln(os.Stdout, os.Args[4])
|
|
os.Exit(0)
|
|
}
|
|
}
|
|
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
code := m.Run()
|
|
|
|
testutil.CleanupTestUsers()
|
|
|
|
os.Exit(code)
|
|
}
|
|
|
|
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
|
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
|
mockDaemon := startMockDaemon(t)
|
|
defer mockDaemon.stop()
|
|
|
|
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
require.NoError(t, err)
|
|
defer func() { _ = grpcConn.Close() }()
|
|
|
|
proxy := &SSHProxy{
|
|
daemonAddr: mockDaemon.addr,
|
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
|
}
|
|
|
|
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
|
require.NoError(t, err)
|
|
|
|
mockDaemon.setHostKey("test-host", testPubKey)
|
|
|
|
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
|
assert.NoError(t, err)
|
|
})
|
|
|
|
t.Run("rejects unknown host key", func(t *testing.T) {
|
|
mockDaemon := startMockDaemon(t)
|
|
defer mockDaemon.stop()
|
|
|
|
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
|
require.NoError(t, err)
|
|
defer func() { _ = grpcConn.Close() }()
|
|
|
|
proxy := &SSHProxy{
|
|
daemonAddr: mockDaemon.addr,
|
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
|
}
|
|
|
|
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
|
require.NoError(t, err)
|
|
|
|
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
|
assert.Error(t, err)
|
|
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
|
})
|
|
}
|
|
|
|
func TestSSHProxy_Connect(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping integration test in short mode")
|
|
}
|
|
|
|
// TODO: Windows test times out - user switching and command execution tested on Linux
|
|
if runtime.GOOS == "windows" {
|
|
t.Skip("Skipping on Windows - covered by Linux tests")
|
|
}
|
|
|
|
const (
|
|
issuer = "https://test-issuer.example.com"
|
|
audience = "test-audience"
|
|
)
|
|
|
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
defer jwksServer.Close()
|
|
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
require.NoError(t, err)
|
|
|
|
serverConfig := &server.Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: &server.JWTConfig{
|
|
Issuer: issuer,
|
|
Audiences: []string{audience},
|
|
KeysLocation: jwksURL,
|
|
},
|
|
}
|
|
sshServer := server.New(serverConfig)
|
|
sshServer.SetAllowRootLogin(true)
|
|
|
|
// Configure SSH authorization for the test user
|
|
testUsername := testutil.GetTestUsername(t)
|
|
testJWTUser := "test-username"
|
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
require.NoError(t, err)
|
|
|
|
authConfig := &sshauth.Config{
|
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
MachineUsers: map[string][]uint32{
|
|
testUsername: {0}, // Index 0 in AuthorizedUsers
|
|
},
|
|
}
|
|
sshServer.UpdateSSHAuth(authConfig)
|
|
|
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
defer func() { _ = sshServer.Stop() }()
|
|
|
|
mockDaemon := startMockDaemon(t)
|
|
defer mockDaemon.stop()
|
|
|
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
require.NoError(t, err)
|
|
port, err := strconv.Atoi(portStr)
|
|
require.NoError(t, err)
|
|
|
|
mockDaemon.setHostKey(host, hostPubKey)
|
|
|
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
mockDaemon.setJWTToken(validToken)
|
|
|
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
require.NoError(t, err)
|
|
|
|
clientConn, proxyConn := net.Pipe()
|
|
defer func() { _ = clientConn.Close() }()
|
|
|
|
origStdin := os.Stdin
|
|
origStdout := os.Stdout
|
|
defer func() {
|
|
os.Stdin = origStdin
|
|
os.Stdout = origStdout
|
|
}()
|
|
|
|
stdinReader, stdinWriter, err := os.Pipe()
|
|
require.NoError(t, err)
|
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
require.NoError(t, err)
|
|
|
|
os.Stdin = stdinReader
|
|
os.Stdout = stdoutWriter
|
|
|
|
go func() {
|
|
_, _ = io.Copy(stdinWriter, proxyConn)
|
|
}()
|
|
go func() {
|
|
_, _ = io.Copy(proxyConn, stdoutReader)
|
|
}()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
defer cancel()
|
|
|
|
connectErrCh := make(chan error, 1)
|
|
go func() {
|
|
connectErrCh <- proxyInstance.Connect(ctx)
|
|
}()
|
|
|
|
sshConfig := &cryptossh.ClientConfig{
|
|
User: testutil.GetTestUsername(t),
|
|
Auth: []cryptossh.AuthMethod{},
|
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
Timeout: 3 * time.Second,
|
|
}
|
|
|
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
require.NoError(t, err, "Should connect to proxy server")
|
|
defer func() { _ = sshClientConn.Close() }()
|
|
|
|
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err, "Should create session through full proxy to backend")
|
|
|
|
outputCh := make(chan []byte, 1)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
output, err := session.Output("echo hello-from-proxy")
|
|
outputCh <- output
|
|
errCh <- err
|
|
}()
|
|
|
|
select {
|
|
case output := <-outputCh:
|
|
err := <-errCh
|
|
require.NoError(t, err, "Command should execute successfully through proxy")
|
|
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
|
case <-time.After(3 * time.Second):
|
|
t.Fatal("Command execution timed out")
|
|
}
|
|
|
|
_ = session.Close()
|
|
_ = sshClient.Close()
|
|
_ = clientConn.Close()
|
|
cancel()
|
|
}
|
|
|
|
// TestSSHProxy_CommandQuoting verifies that the proxy preserves shell quoting
|
|
// when forwarding commands to the backend. This is critical for tools like
|
|
// Ansible that send commands such as:
|
|
//
|
|
// /bin/sh -c '( umask 77 && mkdir -p ... ) && sleep 0'
|
|
//
|
|
// The single quotes must be preserved so the backend shell receives the
|
|
// subshell expression as a single argument to -c.
|
|
func TestSSHProxy_CommandQuoting(t *testing.T) {
|
|
if testing.Short() {
|
|
t.Skip("Skipping integration test in short mode")
|
|
}
|
|
|
|
sshClient, cleanup := setupProxySSHClient(t)
|
|
defer cleanup()
|
|
|
|
// These commands simulate what the SSH protocol delivers as exec payloads.
|
|
// When a user types: ssh host '/bin/sh -c "( echo hello )"'
|
|
// the local shell strips the outer single quotes, and the SSH exec request
|
|
// contains the raw string: /bin/sh -c "( echo hello )"
|
|
//
|
|
// The proxy must forward this string verbatim. Using session.Command()
|
|
// (shlex.Split + strings.Join) strips the inner double quotes, breaking
|
|
// the command on the backend.
|
|
tests := []struct {
|
|
name string
|
|
command string
|
|
expect string
|
|
}{
|
|
{
|
|
name: "subshell_in_double_quotes",
|
|
command: `/bin/sh -c "( echo from-subshell ) && echo outer"`,
|
|
expect: "from-subshell\nouter\n",
|
|
},
|
|
{
|
|
name: "printf_with_special_chars",
|
|
command: `/bin/sh -c "printf '%s\n' 'hello world'"`,
|
|
expect: "hello world\n",
|
|
},
|
|
{
|
|
name: "nested_command_substitution",
|
|
command: `/bin/sh -c "echo $(echo nested)"`,
|
|
expect: "nested\n",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
session, err := sshClient.NewSession()
|
|
require.NoError(t, err)
|
|
defer func() { _ = session.Close() }()
|
|
|
|
var stderrBuf bytes.Buffer
|
|
session.Stderr = &stderrBuf
|
|
|
|
outputCh := make(chan []byte, 1)
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
output, err := session.Output(tc.command)
|
|
outputCh <- output
|
|
errCh <- err
|
|
}()
|
|
|
|
select {
|
|
case output := <-outputCh:
|
|
err := <-errCh
|
|
if stderrBuf.Len() > 0 {
|
|
t.Logf("stderr: %s", stderrBuf.String())
|
|
}
|
|
require.NoError(t, err, "command should succeed: %s", tc.command)
|
|
assert.Equal(t, tc.expect, string(output), "output mismatch for: %s", tc.command)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatalf("command timed out: %s", tc.command)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// setupProxySSHClient creates a full proxy test environment and returns
|
|
// an SSH client connected through the proxy to a backend NetBird SSH server.
|
|
func setupProxySSHClient(t *testing.T) (*cryptossh.Client, func()) {
|
|
t.Helper()
|
|
|
|
const (
|
|
issuer = "https://test-issuer.example.com"
|
|
audience = "test-audience"
|
|
)
|
|
|
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
|
|
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
require.NoError(t, err)
|
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
|
require.NoError(t, err)
|
|
|
|
serverConfig := &server.Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: &server.JWTConfig{
|
|
Issuer: issuer,
|
|
Audiences: []string{audience},
|
|
KeysLocation: jwksURL,
|
|
},
|
|
}
|
|
sshServer := server.New(serverConfig)
|
|
sshServer.SetAllowRootLogin(true)
|
|
|
|
testUsername := testutil.GetTestUsername(t)
|
|
testJWTUser := "test-username"
|
|
testUserHash, err := sshuserhash.HashUserID(testJWTUser)
|
|
require.NoError(t, err)
|
|
|
|
authConfig := &sshauth.Config{
|
|
UserIDClaim: sshauth.DefaultUserIDClaim,
|
|
AuthorizedUsers: []sshuserhash.UserIDHash{testUserHash},
|
|
MachineUsers: map[string][]uint32{
|
|
testUsername: {0},
|
|
},
|
|
}
|
|
sshServer.UpdateSSHAuth(authConfig)
|
|
|
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
|
|
|
mockDaemon := startMockDaemon(t)
|
|
|
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
|
require.NoError(t, err)
|
|
port, err := strconv.Atoi(portStr)
|
|
require.NoError(t, err)
|
|
|
|
mockDaemon.setHostKey(host, hostPubKey)
|
|
|
|
validToken := generateValidJWT(t, privateKey, issuer, audience, testJWTUser)
|
|
mockDaemon.setJWTToken(validToken)
|
|
|
|
proxyInstance, err := New(mockDaemon.addr, host, port, io.Discard, nil)
|
|
require.NoError(t, err)
|
|
|
|
origStdin := os.Stdin
|
|
origStdout := os.Stdout
|
|
|
|
stdinReader, stdinWriter, err := os.Pipe()
|
|
require.NoError(t, err)
|
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
|
require.NoError(t, err)
|
|
|
|
os.Stdin = stdinReader
|
|
os.Stdout = stdoutWriter
|
|
|
|
clientConn, proxyConn := net.Pipe()
|
|
|
|
go func() { _, _ = io.Copy(stdinWriter, proxyConn) }()
|
|
go func() { _, _ = io.Copy(proxyConn, stdoutReader) }()
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
|
|
|
go func() {
|
|
_ = proxyInstance.Connect(ctx)
|
|
}()
|
|
|
|
sshConfig := &cryptossh.ClientConfig{
|
|
User: testutil.GetTestUsername(t),
|
|
Auth: []cryptossh.AuthMethod{},
|
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
|
require.NoError(t, err)
|
|
|
|
client := cryptossh.NewClient(sshClientConn, chans, reqs)
|
|
|
|
cleanupFn := func() {
|
|
_ = client.Close()
|
|
_ = clientConn.Close()
|
|
cancel()
|
|
os.Stdin = origStdin
|
|
os.Stdout = origStdout
|
|
_ = sshServer.Stop()
|
|
mockDaemon.stop()
|
|
jwksServer.Close()
|
|
}
|
|
|
|
return client, cleanupFn
|
|
}
|
|
|
|
type mockDaemonServer struct {
|
|
proto.UnimplementedDaemonServiceServer
|
|
hostKeys map[string][]byte
|
|
jwtToken string
|
|
}
|
|
|
|
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
|
key, found := m.hostKeys[req.PeerAddress]
|
|
return &proto.GetPeerSSHHostKeyResponse{
|
|
Found: found,
|
|
SshHostKey: key,
|
|
}, nil
|
|
}
|
|
|
|
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
|
return &proto.RequestJWTAuthResponse{
|
|
CachedToken: m.jwtToken,
|
|
}, nil
|
|
}
|
|
|
|
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
|
return &proto.WaitJWTTokenResponse{
|
|
Token: m.jwtToken,
|
|
}, nil
|
|
}
|
|
|
|
type mockDaemon struct {
|
|
addr string
|
|
server *grpc.Server
|
|
impl *mockDaemonServer
|
|
}
|
|
|
|
func startMockDaemon(t *testing.T) *mockDaemon {
|
|
t.Helper()
|
|
|
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
impl := &mockDaemonServer{
|
|
hostKeys: make(map[string][]byte),
|
|
jwtToken: "test-jwt-token",
|
|
}
|
|
|
|
grpcServer := grpc.NewServer()
|
|
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
|
|
|
go func() {
|
|
_ = grpcServer.Serve(listener)
|
|
}()
|
|
|
|
return &mockDaemon{
|
|
addr: listener.Addr().String(),
|
|
server: grpcServer,
|
|
impl: impl,
|
|
}
|
|
}
|
|
|
|
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
|
m.impl.hostKeys[addr] = pubKey
|
|
}
|
|
|
|
func (m *mockDaemon) setJWTToken(token string) {
|
|
m.impl.jwtToken = token
|
|
}
|
|
|
|
func (m *mockDaemon) stop() {
|
|
if m.server != nil {
|
|
m.server.Stop()
|
|
}
|
|
}
|
|
|
|
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
|
t.Helper()
|
|
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
|
require.NoError(t, err)
|
|
return pubKey
|
|
}
|
|
|
|
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
|
t.Helper()
|
|
privateKey, jwksJSON := generateTestJWKS(t)
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if _, err := w.Write(jwksJSON); err != nil {
|
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
|
}
|
|
}))
|
|
|
|
return server, privateKey, server.URL
|
|
}
|
|
|
|
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
|
t.Helper()
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
require.NoError(t, err)
|
|
|
|
publicKey := &privateKey.PublicKey
|
|
n := publicKey.N.Bytes()
|
|
e := publicKey.E
|
|
|
|
jwk := nbjwt.JSONWebKey{
|
|
Kty: "RSA",
|
|
Kid: "test-key-id",
|
|
Use: "sig",
|
|
N: base64.RawURLEncoding.EncodeToString(n),
|
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
|
}
|
|
|
|
jwks := nbjwt.Jwks{
|
|
Keys: []nbjwt.JSONWebKey{jwk},
|
|
}
|
|
|
|
jwksJSON, err := json.Marshal(jwks)
|
|
require.NoError(t, err)
|
|
|
|
return privateKey, jwksJSON
|
|
}
|
|
|
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string, user string) string {
|
|
t.Helper()
|
|
claims := jwt.MapClaims{
|
|
"iss": issuer,
|
|
"aud": audience,
|
|
"sub": user,
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
"iat": time.Now().Unix(),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
|
token.Header["kid"] = "test-key-id"
|
|
|
|
tokenString, err := token.SignedString(privateKey)
|
|
require.NoError(t, err)
|
|
|
|
return tokenString
|
|
}
|