[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

392
client/ssh/proxy/proxy.go Normal file
View File

@@ -0,0 +1,392 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/version"
)
const (
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
sshConnectionTimeout = 120 * time.Second
// sshHandshakeTimeout is the timeout for SSH handshake completion
sshHandshakeTimeout = 30 * time.Second
jwtAuthErrorMsg = "JWT authentication: %w"
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to daemon: %w", err)
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Close() error {
if p.conn != nil {
return p.conn.Close()
}
return nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint()
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(s ssh.Session) {
p.sftpSubsystemHandler(s, jwtToken)
},
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": p.tcpipForwardHandler,
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
},
Version: serverVersion,
}
hostKey, err := generateHostKey()
if err != nil {
return fmt.Errorf("generate host key: %w", err)
}
sshServer.HostSigners = []ssh.Signer{hostKey}
conn := &stdioConn{
stdin: os.Stdin,
stdout: os.Stdout,
}
sshServer.HandleConn(conn)
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
}
go func() {
for win := range winCh {
if err := serverSession.WindowChange(win.Height, win.Width); err != nil {
log.Debugf("window change: %v", err)
}
}
}()
}
if len(session.Command()) > 0 {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
}
return
}
if err = serverSession.Shell(); err != nil {
log.Debugf("start shell: %v", err)
return
}
if err := serverSession.Wait(); err != nil {
log.Debugf("session wait: %v", err)
p.handleProxyExitCode(session, err)
}
}
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
}
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, fmt.Errorf("generate ED25519 key: %w", err)
}
signer, err := cryptossh.ParsePrivateKey(keyPEM)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return signer, nil
}
type stdioConn struct {
stdin io.Reader
stdout io.Writer
closed bool
mu sync.Mutex
}
func (c *stdioConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
return c.stdin.Read(b)
}
func (c *stdioConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.ErrClosedPipe
}
c.mu.Unlock()
return c.stdout.Write(b)
}
func (c *stdioConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *stdioConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
ctx, cancel := context.WithCancel(s.Context())
defer cancel()
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := sshClient.Close(); err != nil {
log.Debugf("close SSH client: %v", err)
}
}()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := serverSession.Close(); err != nil {
log.Debugf("close server session: %v", err)
}
}()
stdin, stdout, err := p.setupSFTPPipes(serverSession)
if err != nil {
log.Debugf("setup SFTP pipes: %v", err)
_ = s.Exit(1)
return
}
if err := serverSession.RequestSubsystem("sftp"); err != nil {
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
_ = s.Exit(1)
return
}
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
}
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
stdin, err := serverSession.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
}
stdout, err := serverSession.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
}
return stdin, stdout, nil
}
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
copyErrCh := make(chan error, 2)
go func() {
_, err := io.Copy(stdin, s)
if err != nil {
log.Debugf("SFTP client to server copy: %v", err)
}
if err := stdin.Close(); err != nil {
log.Debugf("close stdin: %v", err)
}
copyErrCh <- err
}()
go func() {
_, err := io.Copy(s, stdout)
if err != nil {
log.Debugf("SFTP server to client copy: %v", err)
}
copyErrCh <- err
}()
go func() {
<-ctx.Done()
if err := serverSession.Close(); err != nil {
log.Debugf("force close server session on context cancellation: %v", err)
}
}()
for i := 0; i < 2; i++ {
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
log.Debugf("SFTP copy error: %v", err)
}
}
if err := serverSession.Wait(); err != nil {
log.Debugf("SFTP session ended: %v", err)
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
config := &cryptossh.ClientConfig{
User: user,
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
Timeout: sshHandshakeTimeout,
HostKeyCallback: p.verifyHostKey,
}
dialer := &net.Dialer{
Timeout: sshConnectionTimeout,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to server: %w", err)
}
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("SSH handshake: %w", err)
}
return cryptossh.NewClient(clientConn, chans, reqs), nil
}
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}

View File

@@ -0,0 +1,367 @@
package proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"runtime"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "ssh" {
if os.Args[2] == "exec" {
if len(os.Args) > 3 {
cmd := os.Args[3]
if cmd == "echo" && len(os.Args) > 4 {
fmt.Fprintln(os.Stdout, os.Args[4])
os.Exit(0)
}
}
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
os.Exit(1)
}
}
code := m.Run()
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHProxy_verifyHostKey(t *testing.T) {
t.Run("calls daemon to verify host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testPubKey, err := nbssh.GeneratePublicKey(testKey)
require.NoError(t, err)
mockDaemon.setHostKey("test-host", testPubKey)
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
assert.NoError(t, err)
})
t.Run("rejects unknown host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
require.NoError(t, err)
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
// TODO: Windows test times out - user switching and command execution tested on Linux
if runtime.GOOS == "windows" {
t.Skip("Skipping on Windows - covered by Linux tests")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
jwtToken string
}
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
key, found := m.hostKeys[req.PeerAddress]
return &proto.GetPeerSSHHostKeyResponse{
Found: found,
SshHostKey: key,
}, nil
}
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
return &proto.RequestJWTAuthResponse{
CachedToken: m.jwtToken,
}, nil
}
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
return &proto.WaitJWTTokenResponse{
Token: m.jwtToken,
}, nil
}
type mockDaemon struct {
addr string
server *grpc.Server
impl *mockDaemonServer
}
func startMockDaemon(t *testing.T) *mockDaemon {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
impl := &mockDaemonServer{
hostKeys: make(map[string][]byte),
jwtToken: "test-jwt-token",
}
grpcServer := grpc.NewServer()
proto.RegisterDaemonServiceServer(grpcServer, impl)
go func() {
_ = grpcServer.Serve(listener)
}()
return &mockDaemon{
addr: listener.Addr().String(),
server: grpcServer,
impl: impl,
}
}
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
}
}
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
t.Helper()
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}