[client] Add port forwarding to ssh proxy (#5031)

* Implement port forwarding for the ssh proxy

* Allow user switching for port forwarding
This commit is contained in:
Viktor Liu
2026-01-07 12:18:04 +08:00
committed by GitHub
parent 7142d45ef3
commit f012fb8592
15 changed files with 1006 additions and 370 deletions

View File

@@ -6,37 +6,45 @@ import (
"errors"
"fmt"
"io"
"strings"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
)
// associateJWTUsername extracts pending JWT username for the session and associates it with the session state.
// Returns the JWT username (empty if none) for logging purposes.
func (s *Server) associateJWTUsername(sess ssh.Session, sessionKey sessionKey) string {
key := newAuthKey(sess.User(), sess.RemoteAddr())
s.mu.Lock()
defer s.mu.Unlock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername == "" {
return ""
}
if state, exists := s.sessions[sessionKey]; exists {
state.jwtUsername = jwtUsername
}
delete(s.pendingAuthJWT, key)
return jwtUsername
}
// sessionHandler handles SSH sessions
func (s *Server) sessionHandler(session ssh.Session) {
sessionKey := s.registerSession(session)
key := newAuthKey(session.User(), session.RemoteAddr())
s.mu.Lock()
jwtUsername := s.pendingAuthJWT[key]
if jwtUsername != "" {
s.sessionJWTUsers[sessionKey] = jwtUsername
delete(s.pendingAuthJWT, key)
}
s.mu.Unlock()
sessionKey := s.registerSession(session, "")
jwtUsername := s.associateJWTUsername(session, sessionKey)
logger := log.WithField("session", sessionKey)
if jwtUsername != "" {
logger = logger.WithField("jwt_user", jwtUsername)
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
} else {
logger.Infof("SSH session started")
}
logger.Info("SSH session started")
sessionStart := time.Now()
defer s.unregisterSession(sessionKey, session)
defer s.unregisterSession(sessionKey)
defer func() {
duration := time.Since(sessionStart).Round(time.Millisecond)
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
@@ -65,27 +73,52 @@ func (s *Server) sessionHandler(session ssh.Session) {
// ssh <host> <cmd> - non-Pty command execution
s.handleCommand(logger, session, privilegeResult, nil)
default:
s.rejectInvalidSession(logger, session)
// ssh -T (or ssh -N) - no PTY, no command
s.handleNonInteractiveSession(logger, session)
}
}
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
logger.Debugf(errWriteSession, err)
// handleNonInteractiveSession handles sessions that have no PTY and no command.
// These are typically used for port forwarding (ssh -L/-R) or tunneling (ssh -N).
func (s *Server) handleNonInteractiveSession(logger *log.Entry, session ssh.Session) {
s.updateSessionType(session, cmdNonInteractive)
if !s.isPortForwardingEnabled() {
if _, err := io.WriteString(session, "port forwarding is disabled on this server\n"); err != nil {
logger.Debugf(errWriteSession, err)
}
if err := session.Exit(1); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-interactive session: port forwarding disabled")
return
}
if err := session.Exit(1); err != nil {
<-session.Context().Done()
if err := session.Exit(0); err != nil {
logSessionExitError(logger, err)
}
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
}
func (s *Server) registerSession(session ssh.Session) SessionKey {
func (s *Server) updateSessionType(session ssh.Session, sessionType string) {
s.mu.Lock()
defer s.mu.Unlock()
for _, state := range s.sessions {
if state.session == session {
state.sessionType = sessionType
return
}
}
}
func (s *Server) registerSession(session ssh.Session, sessionType string) sessionKey {
sessionID := session.Context().Value(ssh.ContextKeySessionID)
if sessionID == nil {
sessionID = fmt.Sprintf("%p", session)
}
// Create a short 4-byte identifier from the full session ID
hasher := sha256.New()
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
hash := hasher.Sum(nil)
@@ -93,43 +126,23 @@ func (s *Server) registerSession(session ssh.Session) SessionKey {
remoteAddr := session.RemoteAddr().String()
username := session.User()
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
sessionKey := sessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
s.mu.Lock()
s.sessions[sessionKey] = session
s.sessions[sessionKey] = &sessionState{
session: session,
sessionType: sessionType,
}
s.mu.Unlock()
return sessionKey
}
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
func (s *Server) unregisterSession(sessionKey sessionKey) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessions, sessionKey)
delete(s.sessionJWTUsers, sessionKey)
// Cancel all port forwarding connections for this session
var connectionsToCancel []ConnectionKey
for key := range s.sessionCancels {
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
connectionsToCancel = append(connectionsToCancel, key)
}
}
for _, key := range connectionsToCancel {
if cancelFunc, exists := s.sessionCancels[key]; exists {
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
cancelFunc()
delete(s.sessionCancels, key)
}
}
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
delete(s.sshConnections, sshConn)
}
}
s.mu.Unlock()
}
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {