Improve forwarding cancellation

This commit is contained in:
Viktor Liu
2025-08-26 22:22:15 +02:00
parent 77a352763d
commit 79d28b71ee
2 changed files with 41 additions and 36 deletions

View File

@@ -1,7 +1,6 @@
package server
import (
"context"
"encoding/binary"
"fmt"
"io"
@@ -288,8 +287,7 @@ type acceptResult struct {
// handleRemoteForwardConnection handles a single remote port forwarding connection
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
sessionKey := s.findSessionKeyByContext(ctx)
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
connID := fmt.Sprintf("pf-%s->%s:%d", remoteAddr, host, port)
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
logger := log.WithFields(log.Fields{
"session": sessionKey,
"conn": connID,
@@ -307,6 +305,12 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
return
}
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
if !ok {
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
return
}
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
if err != nil {
logger.Debugf("open forward channel: %v", err)
@@ -344,51 +348,37 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
done := make(chan struct{}, 2)
closed := make(chan struct{})
var closeOnce bool
go s.monitorSessionContext(ctx, channel, conn, closed, &closeOnce, logger)
go func() {
defer func() { done <- struct{}{} }()
if _, err := io.Copy(channel, conn); err != nil {
logger.Debugf("copy error (conn->channel): %v", err)
}
done <- struct{}{}
}()
go func() {
defer func() { done <- struct{}{} }()
if _, err := io.Copy(conn, channel); err != nil {
logger.Debugf("copy error (channel->conn): %v", err)
}
done <- struct{}{}
}()
<-done
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
case <-closed:
// First copy finished, wait for second copy or context cancellation
select {
case <-ctx.Done():
logger.Debugf("session ended, closing connections")
case <-done:
}
}
if !closeOnce {
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
}
}
// monitorSessionContext watches for session cancellation and closes connections
func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) {
<-ctx.Done()
logger.Debugf("session ended, closing connections")
if !*closeOnce {
*closeOnce = true
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
close(closed)
if err := channel.Close(); err != nil {
logger.Debugf("channel close error: %v", err)
}
if err := conn.Close(); err != nil {
logger.Debugf("connection close error: %v", err)
}
}