[client] Add port forwarding to ssh proxy (#5031)

* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
This commit is contained in:
Viktor Liu
2026-01-07 12:18:04 +08:00
committed by GitHub
parent 7142d45ef3
commit f012fb8592
15 changed files with 1006 additions and 370 deletions

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"encoding/binary"
"errors"
"fmt"
"io"
@@ -42,6 +43,14 @@ type SSHProxy struct {
conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient
browserOpener func(string) error
mu sync.RWMutex
backendClient *cryptossh.Client
// jwtToken is set once in runProxySSHServer before any handlers are called,
// so concurrent access is safe without additional synchronization.
jwtToken string
forwardedChannelsOnce sync.Once
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
@@ -63,6 +72,17 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browse
}
func (p *SSHProxy) Close() error {
p.mu.Lock()
backendClient := p.backendClient
p.backendClient = nil
p.mu.Unlock()
if backendClient != nil {
if err := backendClient.Close(); err != nil {
log.Debugf("close backend client: %v", err)
}
}
if p.conn != nil {
return p.conn.Close()
}
@@ -77,16 +97,16 @@ func (p *SSHProxy) Connect(ctx context.Context) error {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
log.Debugf("JWT authentication successful, starting proxy to %s:%d", p.targetHost, p.targetPort)
return p.runProxySSHServer(jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
func (p *SSHProxy) runProxySSHServer(jwtToken string) error {
p.jwtToken = jwtToken
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
Handler: p.handleSSHSession,
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
@@ -119,15 +139,20 @@ func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
func (p *SSHProxy) handleSSHSession(session ssh.Session) {
ptyReq, winCh, isPty := session.Pty()
hasCommand := len(session.Command()) > 0
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
sshClient, err := p.getOrCreateBackendClient(session.Context(), session.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
if !isPty && !hasCommand {
p.handleNonInteractiveSession(session, sshClient)
return
}
serverSession, err := sshClient.NewSession()
if err != nil {
@@ -140,7 +165,6 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
if err := serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil); err != nil {
log.Debugf("PTY request to backend: %v", err)
@@ -155,7 +179,7 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
}()
}
if len(session.Command()) > 0 {
if hasCommand {
if err := serverSession.Run(strings.Join(session.Command(), " ")); err != nil {
log.Debugf("run command: %v", err)
p.handleProxyExitCode(session, err)
@@ -176,12 +200,29 @@ func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jw
func (p *SSHProxy) handleProxyExitCode(session ssh.Session, err error) {
var exitErr *cryptossh.ExitError
if errors.As(err, &exitErr) {
if exitErr := session.Exit(exitErr.ExitStatus()); exitErr != nil {
log.Debugf("set exit status: %v", exitErr)
if err := session.Exit(exitErr.ExitStatus()); err != nil {
log.Debugf("set exit status: %v", err)
}
}
}
func (p *SSHProxy) handleNonInteractiveSession(session ssh.Session, sshClient *cryptossh.Client) {
// Create a backend session to mirror the client's session request.
// This keeps the connection alive on the server side while port forwarding channels operate.
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
<-session.Context().Done()
if err := session.Exit(0); err != nil {
log.Debugf("session exit: %v", err)
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
@@ -250,8 +291,52 @@ func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
// directTCPIPHandler handles local port forwarding (direct-tcpip channel).
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, sshCtx ssh.Context) {
var payload struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse direct-tcpip payload: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "invalid payload")
return
}
dest := fmt.Sprintf("%s:%d", payload.DestAddr, payload.DestPort)
log.Debugf("local port forwarding: %s", dest)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for port forwarding: %v\n", err)
_ = newChan.Reject(cryptossh.ConnectionFailed, "backend connection failed")
return
}
backendChan, backendReqs, err := backendClient.OpenChannel("direct-tcpip", newChan.ExtraData())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "open backend channel for %s: %v\n", dest, err)
var openErr *cryptossh.OpenChannelError
if errors.As(err, &openErr) {
_ = newChan.Reject(openErr.Reason, openErr.Message)
} else {
_ = newChan.Reject(cryptossh.ConnectionFailed, err.Error())
}
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := newChan.Accept()
if err != nil {
log.Debugf("local port forwarding: accept channel: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
@@ -354,12 +439,143 @@ func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.Wr
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
// tcpipForwardHandler handles remote port forwarding (tcpip-forward request).
func (p *SSHProxy) tcpipForwardHandler(sshCtx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient, err := p.getOrCreateBackendClient(sshCtx, sshCtx.User())
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "backend connection for remote port forwarding: %v\n", err)
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "forward tcpip-forward request for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
if ok {
actualPort := reqPayload.Port
if reqPayload.Port == 0 && len(payload) >= 4 {
actualPort = binary.BigEndian.Uint32(payload)
}
log.Debugf("remote port forwarding established for %s:%d", reqPayload.Host, actualPort)
p.forwardedChannelsOnce.Do(func() {
go p.handleForwardedChannels(sshCtx, backendClient)
})
}
return ok, payload
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
// cancelTcpipForwardHandler handles cancel-tcpip-forward request.
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
var reqPayload struct {
Host string
Port uint32
}
if err := cryptossh.Unmarshal(req.Payload, &reqPayload); err != nil {
_, _ = fmt.Fprintf(p.stderr, "parse cancel-tcpip-forward payload: %v\n", err)
return false, nil
}
log.Debugf("cancel-tcpip-forward request for %s:%d", reqPayload.Host, reqPayload.Port)
backendClient := p.getBackendClient()
if backendClient == nil {
return false, nil
}
ok, payload, err := backendClient.SendRequest(req.Type, req.WantReply, req.Payload)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "cancel-tcpip-forward for %s:%d: %v\n", reqPayload.Host, reqPayload.Port, err)
return false, nil
}
return ok, payload
}
// getOrCreateBackendClient returns the existing backend client or creates a new one.
func (p *SSHProxy) getOrCreateBackendClient(ctx context.Context, user string) (*cryptossh.Client, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.backendClient != nil {
return p.backendClient, nil
}
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
log.Debugf("connecting to backend %s", targetAddr)
client, err := p.dialBackend(ctx, targetAddr, user, p.jwtToken)
if err != nil {
return nil, err
}
log.Debugf("backend connection established to %s", targetAddr)
p.backendClient = client
return client, nil
}
// getBackendClient returns the existing backend client or nil.
func (p *SSHProxy) getBackendClient() *cryptossh.Client {
p.mu.RLock()
defer p.mu.RUnlock()
return p.backendClient
}
// handleForwardedChannels handles forwarded-tcpip channels from the backend for remote port forwarding.
// When the backend receives incoming connections on the forwarded port, it sends them as
// "forwarded-tcpip" channels which we need to proxy to the client.
func (p *SSHProxy) handleForwardedChannels(sshCtx ssh.Context, backendClient *cryptossh.Client) {
sshConn, ok := sshCtx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
if !ok || sshConn == nil {
log.Debugf("no SSH connection in context for forwarded channels")
return
}
channelChan := backendClient.HandleChannelOpen("forwarded-tcpip")
for {
select {
case <-sshCtx.Done():
return
case newChannel, ok := <-channelChan:
if !ok {
return
}
go p.handleForwardedChannel(sshCtx, sshConn, newChannel)
}
}
}
// handleForwardedChannel handles a single forwarded-tcpip channel from the backend.
func (p *SSHProxy) handleForwardedChannel(sshCtx ssh.Context, sshConn *cryptossh.ServerConn, newChannel cryptossh.NewChannel) {
backendChan, backendReqs, err := newChannel.Accept()
if err != nil {
log.Debugf("remote port forwarding: accept from backend: %v", err)
return
}
go cryptossh.DiscardRequests(backendReqs)
clientChan, clientReqs, err := sshConn.OpenChannel("forwarded-tcpip", newChannel.ExtraData())
if err != nil {
log.Debugf("remote port forwarding: open to client: %v", err)
_ = backendChan.Close()
return
}
go cryptossh.DiscardRequests(clientReqs)
nbssh.BidirectionalCopyWithContext(log.NewEntry(log.StandardLogger()), sshCtx, clientChan, backendChan)
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {