mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Improve forwarding cancellation
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -288,8 +287,7 @@ type acceptResult struct {
|
||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||
sessionKey := s.findSessionKeyByContext(ctx)
|
||||
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", remoteAddr, host, port)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"session": sessionKey,
|
||||
"conn": connID,
|
||||
@@ -307,6 +305,12 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel: %v", err)
|
||||
@@ -344,51 +348,37 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
|
||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
||||
done := make(chan struct{}, 2)
|
||||
closed := make(chan struct{})
|
||||
var closeOnce bool
|
||||
|
||||
go s.monitorSessionContext(ctx, channel, conn, closed, &closeOnce, logger)
|
||||
|
||||
go func() {
|
||||
defer func() { done <- struct{}{} }()
|
||||
if _, err := io.Copy(channel, conn); err != nil {
|
||||
logger.Debugf("copy error (conn->channel): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() { done <- struct{}{} }()
|
||||
if _, err := io.Copy(conn, channel); err != nil {
|
||||
logger.Debugf("copy error (channel->conn): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
case <-closed:
|
||||
// First copy finished, wait for second copy or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
if !closeOnce {
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorSessionContext watches for session cancellation and closes connections
|
||||
func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) {
|
||||
<-ctx.Done()
|
||||
logger.Debugf("session ended, closing connections")
|
||||
|
||||
if !*closeOnce {
|
||||
*closeOnce = true
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
close(closed)
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user