diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index ea2fa409a..0b7c1a88c 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -23,6 +23,13 @@ import ( "github.com/netbirdio/netbird/client/proto" ) +const ( + // DefaultDaemonAddr is the default address for the NetBird daemon + DefaultDaemonAddr = "unix:///var/run/netbird.sock" + // DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows + DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731" +) + // Client wraps crypto/ssh Client for simplified SSH operations type Client struct { client *ssh.Client @@ -172,7 +179,7 @@ func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error { session, cleanup, err := c.createSession(ctx) if err != nil { - return err + return fmt.Errorf("create session: %w", err) } defer cleanup() @@ -335,7 +342,15 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) ( // createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) { - return createHostKeyCallbackWithDaemonAddr(addr, "unix:///var/run/netbird.sock") + daemonAddr := os.Getenv("NB_DAEMON_ADDR") + if daemonAddr == "" { + if runtime.GOOS == "windows" { + daemonAddr = DefaultDaemonAddrWindows + } else { + daemonAddr = DefaultDaemonAddr + } + } + return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr) } // createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address @@ -617,12 +632,12 @@ func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) { func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error { host, port, err := c.parseRemoteAddress(remoteAddr) if err != nil { - return err + return fmt.Errorf("parse remote address: %w", err) } req := c.buildTCPIPForwardRequest(host, port) if err := c.sendTCPIPForwardRequest(req); err != nil { - return err + return fmt.Errorf("setup remote forward: %w", err) } go c.handleRemoteForwardChannels(ctx, localAddr) diff --git a/client/ssh/server/port_forwarding.go b/client/ssh/server/port_forwarding.go index 4cdac9c4a..7eb249cc9 100644 --- a/client/ssh/server/port_forwarding.go +++ b/client/ssh/server/port_forwarding.go @@ -1,7 +1,6 @@ package server import ( - "context" "encoding/binary" "fmt" "io" @@ -288,8 +287,7 @@ type acceptResult struct { // handleRemoteForwardConnection handles a single remote port forwarding connection func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) { sessionKey := s.findSessionKeyByContext(ctx) - remoteAddr := conn.RemoteAddr().(*net.TCPAddr) - connID := fmt.Sprintf("pf-%s->%s:%d", remoteAddr, host, port) + connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port) logger := log.WithFields(log.Fields{ "session": sessionKey, "conn": connID, @@ -307,6 +305,12 @@ func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, h return } + remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr) + if !ok { + logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr()) + return + } + channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger) if err != nil { logger.Debugf("open forward channel: %v", err) @@ -344,51 +348,37 @@ func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, // proxyForwardConnection handles bidirectional data transfer between connection and SSH channel func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) { done := make(chan struct{}, 2) - closed := make(chan struct{}) - var closeOnce bool - - go s.monitorSessionContext(ctx, channel, conn, closed, &closeOnce, logger) go func() { - defer func() { done <- struct{}{} }() if _, err := io.Copy(channel, conn); err != nil { logger.Debugf("copy error (conn->channel): %v", err) } + done <- struct{}{} }() go func() { - defer func() { done <- struct{}{} }() if _, err := io.Copy(conn, channel); err != nil { logger.Debugf("copy error (channel->conn): %v", err) } + done <- struct{}{} }() - <-done select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") case <-done: - case <-closed: + // First copy finished, wait for second copy or context cancellation + select { + case <-ctx.Done(): + logger.Debugf("session ended, closing connections") + case <-done: + } } - if !closeOnce { - if err := channel.Close(); err != nil { - logger.Debugf("channel close error: %v", err) - } - } -} - -// monitorSessionContext watches for session cancellation and closes connections -func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) { - <-ctx.Done() - logger.Debugf("session ended, closing connections") - - if !*closeOnce { - *closeOnce = true - if err := channel.Close(); err != nil { - logger.Debugf("channel close error: %v", err) - } - if err := conn.Close(); err != nil { - logger.Debugf("connection close error: %v", err) - } - close(closed) + if err := channel.Close(); err != nil { + logger.Debugf("channel close error: %v", err) + } + if err := conn.Close(); err != nil { + logger.Debugf("connection close error: %v", err) } }