mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 17:56:39 +00:00
Improve forwarding cancellation
This commit is contained in:
@@ -23,6 +23,13 @@ import (
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDaemonAddr is the default address for the NetBird daemon
|
||||
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
|
||||
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
|
||||
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client for simplified SSH operations
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
@@ -172,7 +179,7 @@ func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error
|
||||
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
@@ -335,7 +342,15 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
|
||||
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
|
||||
return createHostKeyCallbackWithDaemonAddr(addr, "unix:///var/run/netbird.sock")
|
||||
daemonAddr := os.Getenv("NB_DAEMON_ADDR")
|
||||
if daemonAddr == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
daemonAddr = DefaultDaemonAddrWindows
|
||||
} else {
|
||||
daemonAddr = DefaultDaemonAddr
|
||||
}
|
||||
}
|
||||
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr)
|
||||
}
|
||||
|
||||
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
|
||||
@@ -617,12 +632,12 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
||||
host, port, err := c.parseRemoteAddress(remoteAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
req := c.buildTCPIPForwardRequest(host, port)
|
||||
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("setup remote forward: %w", err)
|
||||
}
|
||||
|
||||
go c.handleRemoteForwardChannels(ctx, localAddr)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -288,8 +287,7 @@ type acceptResult struct {
|
||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||
sessionKey := s.findSessionKeyByContext(ctx)
|
||||
remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", remoteAddr, host, port)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"session": sessionKey,
|
||||
"conn": connID,
|
||||
@@ -307,6 +305,12 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel: %v", err)
|
||||
@@ -344,51 +348,37 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string,
|
||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
||||
done := make(chan struct{}, 2)
|
||||
closed := make(chan struct{})
|
||||
var closeOnce bool
|
||||
|
||||
go s.monitorSessionContext(ctx, channel, conn, closed, &closeOnce, logger)
|
||||
|
||||
go func() {
|
||||
defer func() { done <- struct{}{} }()
|
||||
if _, err := io.Copy(channel, conn); err != nil {
|
||||
logger.Debugf("copy error (conn->channel): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() { done <- struct{}{} }()
|
||||
if _, err := io.Copy(conn, channel); err != nil {
|
||||
logger.Debugf("copy error (channel->conn): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
<-done
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
case <-closed:
|
||||
// First copy finished, wait for second copy or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
if !closeOnce {
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// monitorSessionContext watches for session cancellation and closes connections
|
||||
func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) {
|
||||
<-ctx.Done()
|
||||
logger.Debugf("session ended, closing connections")
|
||||
|
||||
if !*closeOnce {
|
||||
*closeOnce = true
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
close(closed)
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user