mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 01:36:46 +00:00
695 lines
19 KiB
Go
695 lines
19 KiB
Go
package client
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/crypto/ssh/knownhosts"
|
|
"golang.org/x/term"
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
|
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
"github.com/netbirdio/netbird/client/proto"
|
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
|
"github.com/netbirdio/netbird/util"
|
|
)
|
|
|
|
// Default addresses of the NetBird daemon's gRPC endpoint.
const (
	// DefaultDaemonAddr is the daemon address used on every OS except Windows.
	DefaultDaemonAddr = "unix:///var/run/netbird.sock"

	// DefaultDaemonAddrWindows is the daemon address used on Windows.
	DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
)
|
|
|
|
// Client wraps crypto/ssh Client for simplified SSH operations
|
|
type Client struct {
|
|
client *ssh.Client
|
|
terminalState *term.State
|
|
terminalFd int
|
|
|
|
windowsStdoutMode uint32 // nolint:unused
|
|
windowsStdinMode uint32 // nolint:unused
|
|
}
|
|
|
|
func (c *Client) Close() error {
|
|
return c.client.Close()
|
|
}
|
|
|
|
func (c *Client) OpenTerminal(ctx context.Context) error {
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("new session: %w", err)
|
|
}
|
|
defer func() {
|
|
if err := session.Close(); err != nil {
|
|
log.Debugf("session close error: %v", err)
|
|
}
|
|
}()
|
|
|
|
if err := c.setupTerminalMode(ctx, session); err != nil {
|
|
return err
|
|
}
|
|
|
|
c.setupSessionIO(session)
|
|
|
|
if err := session.Shell(); err != nil {
|
|
return fmt.Errorf("start shell: %w", err)
|
|
}
|
|
|
|
return c.waitForSession(ctx, session)
|
|
}
|
|
|
|
// setupSessionIO connects session streams to local terminal
|
|
func (c *Client) setupSessionIO(session *ssh.Session) {
|
|
session.Stdout = os.Stdout
|
|
session.Stderr = os.Stderr
|
|
session.Stdin = os.Stdin
|
|
}
|
|
|
|
// waitForSession waits for the session to complete with context cancellation
|
|
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- session.Wait()
|
|
}()
|
|
|
|
defer c.restoreTerminal()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
case err := <-done:
|
|
return c.handleSessionError(err)
|
|
}
|
|
}
|
|
|
|
// handleSessionError processes session termination errors
|
|
func (c *Client) handleSessionError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
var e *ssh.ExitError
|
|
var em *ssh.ExitMissingError
|
|
if !errors.As(err, &e) && !errors.As(err, &em) {
|
|
return fmt.Errorf("session wait: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// restoreTerminal restores the terminal to its original state
|
|
func (c *Client) restoreTerminal() {
|
|
if c.terminalState != nil {
|
|
_ = term.Restore(c.terminalFd, c.terminalState)
|
|
c.terminalState = nil
|
|
c.terminalFd = 0
|
|
}
|
|
|
|
if err := c.restoreWindowsConsoleState(); err != nil {
|
|
log.Debugf("restore Windows console state: %v", err)
|
|
}
|
|
}
|
|
|
|
// ExecuteCommand executes a command on the remote host and returns the output
|
|
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
|
|
session, cleanup, err := c.createSession(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer cleanup()
|
|
|
|
output, err := session.CombinedOutput(command)
|
|
if err != nil {
|
|
var e *ssh.ExitError
|
|
var em *ssh.ExitMissingError
|
|
if !errors.As(err, &e) && !errors.As(err, &em) {
|
|
return output, fmt.Errorf("execute command: %w", err)
|
|
}
|
|
}
|
|
|
|
return output, nil
|
|
}
|
|
|
|
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
|
|
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
|
|
session, cleanup, err := c.createSession(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("create session: %w", err)
|
|
}
|
|
defer cleanup()
|
|
|
|
c.setupSessionIO(session)
|
|
|
|
if err := session.Start(command); err != nil {
|
|
return fmt.Errorf("start command: %w", err)
|
|
}
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- session.Wait()
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGTERM)
|
|
select {
|
|
case <-done:
|
|
return ctx.Err()
|
|
case <-time.After(100 * time.Millisecond):
|
|
return ctx.Err()
|
|
}
|
|
case err := <-done:
|
|
return c.handleCommandError(err)
|
|
}
|
|
}
|
|
|
|
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
|
|
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
|
session, cleanup, err := c.createSession(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("create session: %w", err)
|
|
}
|
|
defer cleanup()
|
|
|
|
if err := c.setupTerminalMode(ctx, session); err != nil {
|
|
return fmt.Errorf("setup terminal mode: %w", err)
|
|
}
|
|
|
|
c.setupSessionIO(session)
|
|
|
|
if err := session.Start(command); err != nil {
|
|
return fmt.Errorf("start command: %w", err)
|
|
}
|
|
|
|
defer c.restoreTerminal()
|
|
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
done <- session.Wait()
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGTERM)
|
|
select {
|
|
case <-done:
|
|
return ctx.Err()
|
|
case <-time.After(100 * time.Millisecond):
|
|
return ctx.Err()
|
|
}
|
|
case err := <-done:
|
|
return c.handleCommandError(err)
|
|
}
|
|
}
|
|
|
|
// handleCommandError processes command execution errors
|
|
func (c *Client) handleCommandError(err error) error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
var e *ssh.ExitError
|
|
var em *ssh.ExitMissingError
|
|
if errors.As(err, &e) || errors.As(err, &em) {
|
|
return err
|
|
}
|
|
|
|
return fmt.Errorf("execute command: %w", err)
|
|
}
|
|
|
|
// setupContextCancellation sets up context cancellation for a session
|
|
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
|
|
done := make(chan struct{})
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
_ = session.Signal(ssh.SIGTERM)
|
|
_ = session.Close()
|
|
case <-done:
|
|
}
|
|
}()
|
|
return func() { close(done) }
|
|
}
|
|
|
|
// createSession creates a new SSH session with context cancellation setup
|
|
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
|
|
session, err := c.client.NewSession()
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("new session: %w", err)
|
|
}
|
|
|
|
cancel := c.setupContextCancellation(ctx, session)
|
|
cleanup := func() {
|
|
cancel()
|
|
_ = session.Close()
|
|
}
|
|
|
|
return session, cleanup, nil
|
|
}
|
|
|
|
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
|
func getDefaultDaemonAddr() string {
|
|
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
|
return addr
|
|
}
|
|
if runtime.GOOS == "windows" {
|
|
return DefaultDaemonAddrWindows
|
|
}
|
|
return DefaultDaemonAddr
|
|
}
|
|
|
|
// DialOptions configures how Dial verifies the server's host key and
// authenticates.
type DialOptions struct {
	// KnownHostsFile, when non-empty, is the only known_hosts file consulted
	// during host key verification; when empty the default locations are used.
	KnownHostsFile string

	// IdentityFile is the path of a private key to offer for public-key
	// authentication. Empty means no key authentication is configured.
	IdentityFile string

	// DaemonAddr is the NetBird daemon address. When empty, Dial falls back to
	// the NB_DAEMON_ADDR environment variable or the OS default.
	DaemonAddr string

	// SkipCachedToken is passed (inverted) to the JWT token request
	// (presumably forcing a fresh token — confirm in nbssh.RequestJWTToken).
	SkipCachedToken bool

	// InsecureSkipVerify disables host key verification entirely.
	InsecureSkipVerify bool

	// NoBrowser prevents a browser opener from being handed to the JWT token
	// request.
	NoBrowser bool
}
|
|
|
|
// Dial connects to the given ssh server with specified options
|
|
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
|
daemonAddr := opts.DaemonAddr
|
|
if daemonAddr == "" {
|
|
daemonAddr = getDefaultDaemonAddr()
|
|
}
|
|
opts.DaemonAddr = daemonAddr
|
|
|
|
hostKeyCallback, err := createHostKeyCallback(opts)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create host key callback: %w", err)
|
|
}
|
|
|
|
config := &ssh.ClientConfig{
|
|
User: user,
|
|
Timeout: 30 * time.Second,
|
|
HostKeyCallback: hostKeyCallback,
|
|
}
|
|
|
|
if opts.IdentityFile != "" {
|
|
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create SSH key auth: %w", err)
|
|
}
|
|
config.Auth = append(config.Auth, authMethod)
|
|
}
|
|
|
|
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
|
|
}
|
|
|
|
// dialSSH establishes an SSH connection without JWT authentication
|
|
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
|
dialer := &net.Dialer{}
|
|
conn, err := dialer.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
|
}
|
|
|
|
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
|
if err != nil {
|
|
if closeErr := conn.Close(); closeErr != nil {
|
|
log.Debugf("connection close after handshake failure: %v", closeErr)
|
|
}
|
|
return nil, fmt.Errorf("ssh handshake: %w", err)
|
|
}
|
|
|
|
client := ssh.NewClient(clientConn, chans, reqs)
|
|
return &Client{
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
|
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
|
|
host, portStr, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
|
}
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
|
}
|
|
|
|
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
|
defer cancel()
|
|
|
|
dialer := &net.Dialer{}
|
|
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("SSH server detection: %w", err)
|
|
}
|
|
|
|
if !serverType.RequiresJWT() {
|
|
return dialSSH(ctx, network, addr, config)
|
|
}
|
|
|
|
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
|
defer cancel()
|
|
|
|
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("request JWT token: %w", err)
|
|
}
|
|
|
|
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
|
return dialSSH(ctx, network, addr, configWithJWT)
|
|
}
|
|
|
|
// requestJWTToken requests a JWT token from the NetBird daemon
|
|
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
|
|
hint := profilemanager.GetLoginHint()
|
|
|
|
conn, err := connectToDaemon(daemonAddr)
|
|
if err != nil {
|
|
return "", fmt.Errorf("connect to daemon: %w", err)
|
|
}
|
|
defer conn.Close()
|
|
|
|
client := proto.NewDaemonServiceClient(conn)
|
|
|
|
var browserOpener func(string) error
|
|
if !noBrowser {
|
|
browserOpener = util.OpenBrowser
|
|
}
|
|
|
|
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
|
|
}
|
|
|
|
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
|
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
|
conn, err := connectToDaemon(daemonAddr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
log.Debugf("daemon connection close error: %v", err)
|
|
}
|
|
}()
|
|
|
|
client := proto.NewDaemonServiceClient(conn)
|
|
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
|
callback := nbssh.CreateHostKeyCallback(verifier)
|
|
return callback(hostname, remote, key)
|
|
}
|
|
|
|
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
|
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
|
|
|
conn, err := grpc.NewClient(
|
|
addr,
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
|
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// getKnownHostsFiles lists the default known_hosts locations in order of
// preference: the user's ~/.ssh/known_hosts (when the home directory can be
// determined), followed by the NetBird-managed file and, on non-Windows
// systems, the system-wide /etc/ssh/ssh_known_hosts.
func getKnownHostsFiles() []string {
	var paths []string

	// The user's own known_hosts has the highest priority.
	homeDir, err := os.UserHomeDir()
	if err == nil {
		paths = append(paths, filepath.Join(homeDir, ".ssh", "known_hosts"))
	}

	// NetBird-managed known_hosts files.
	if runtime.GOOS != "windows" {
		return append(paths, "/etc/ssh/ssh_known_hosts.d/99-netbird", "/etc/ssh/ssh_known_hosts")
	}

	programData := os.Getenv("PROGRAMDATA")
	if programData == "" {
		programData = `C:\ProgramData`
	}
	return append(paths, filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird"))
}
|
|
|
|
// createHostKeyCallback creates a host key verification callback
|
|
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
|
if opts.InsecureSkipVerify {
|
|
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
|
}
|
|
|
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
|
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
|
return nil
|
|
}
|
|
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
|
}, nil
|
|
}
|
|
|
|
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
|
if daemonAddr == "" {
|
|
return fmt.Errorf("no daemon address")
|
|
}
|
|
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
|
}
|
|
|
|
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
|
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
|
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
|
|
|
for _, callback := range hostKeyCallbacks {
|
|
if err := callback(hostname, remote, key); err == nil {
|
|
return nil
|
|
}
|
|
}
|
|
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
|
|
}
|
|
|
|
func getKnownHostsFilesList(knownHostsFile string) []string {
|
|
if knownHostsFile != "" {
|
|
return []string{knownHostsFile}
|
|
}
|
|
return getKnownHostsFiles()
|
|
}
|
|
|
|
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
|
var hostKeyCallbacks []ssh.HostKeyCallback
|
|
for _, file := range knownHostsFiles {
|
|
if callback, err := knownhosts.New(file); err == nil {
|
|
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
|
}
|
|
}
|
|
return hostKeyCallbacks
|
|
}
|
|
|
|
// createSSHKeyAuth creates SSH key authentication from a private key file
|
|
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
|
keyData, err := os.ReadFile(keyFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
|
|
}
|
|
|
|
signer, err := ssh.ParsePrivateKey(keyData)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parse SSH private key: %w", err)
|
|
}
|
|
|
|
return ssh.PublicKeys(signer), nil
|
|
}
|
|
|
|
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
|
|
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
|
|
localListener, err := net.Listen("tcp", localAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("listen on %s: %w", localAddr, err)
|
|
}
|
|
|
|
go func() {
|
|
defer func() {
|
|
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
log.Debugf("local listener close error: %v", err)
|
|
}
|
|
}()
|
|
for {
|
|
localConn, err := localListener.Accept()
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
|
|
go c.handleLocalForward(localConn, remoteAddr)
|
|
}
|
|
}()
|
|
|
|
<-ctx.Done()
|
|
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
log.Debugf("local listener close error: %v", err)
|
|
}
|
|
return ctx.Err()
|
|
}
|
|
|
|
// handleLocalForward handles a single local port forwarding connection
|
|
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
|
defer func() {
|
|
if err := localConn.Close(); err != nil {
|
|
log.Debugf("local port forwarding: close local connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
channel, err := c.client.Dial("tcp", remoteAddr)
|
|
if err != nil {
|
|
var openErr *ssh.OpenChannelError
|
|
if errors.As(err, &openErr) && openErr.Reason == ssh.Prohibited {
|
|
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: port forwarding is disabled\n")
|
|
} else {
|
|
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
|
}
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := channel.Close(); err != nil {
|
|
log.Debugf("local port forwarding: close remote channel: %v", err)
|
|
}
|
|
}()
|
|
|
|
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
|
}
|
|
|
|
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
|
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
|
host, port, err := c.parseRemoteAddress(remoteAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("parse remote address: %w", err)
|
|
}
|
|
|
|
req := c.buildTCPIPForwardRequest(host, port)
|
|
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
|
return fmt.Errorf("setup remote forward: %w", err)
|
|
}
|
|
|
|
go c.handleRemoteForwardChannels(ctx, localAddr)
|
|
|
|
<-ctx.Done()
|
|
|
|
if err := c.cancelTCPIPForwardRequest(req); err != nil {
|
|
return fmt.Errorf("cancel tcpip-forward: %w", err)
|
|
}
|
|
return ctx.Err()
|
|
}
|
|
|
|
// parseRemoteAddress parses host and port from remote address string
|
|
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
|
|
host, portStr, err := net.SplitHostPort(remoteAddr)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
|
|
}
|
|
|
|
port, err := strconv.Atoi(portStr)
|
|
if err != nil {
|
|
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
|
|
}
|
|
|
|
return host, uint32(port), nil
|
|
}
|
|
|
|
// buildTCPIPForwardRequest creates a tcpip-forward request message
|
|
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
|
|
return tcpipForwardMsg{
|
|
Host: host,
|
|
Port: port,
|
|
}
|
|
}
|
|
|
|
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
|
|
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
|
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
|
|
if err != nil {
|
|
return fmt.Errorf("send tcpip-forward request: %w", err)
|
|
}
|
|
if !ok {
|
|
return fmt.Errorf("remote port forwarding denied by server")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cancelTCPIPForwardRequest cancels the tcpip-forward request
|
|
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
|
|
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
|
|
if err != nil {
|
|
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
|
|
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
|
|
// Get the channel once - subsequent calls return nil!
|
|
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
|
|
if channelRequests == nil {
|
|
log.Debugf("forwarded-tcpip channel type already being handled")
|
|
return
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case newChan := <-channelRequests:
|
|
if newChan != nil {
|
|
go c.handleRemoteForwardChannel(newChan, localAddr)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
|
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
|
channel, reqs, err := newChan.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := channel.Close(); err != nil {
|
|
log.Debugf("remote port forwarding: close remote channel: %v", err)
|
|
}
|
|
}()
|
|
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
localConn, err := net.Dial("tcp", localAddr)
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer func() {
|
|
if err := localConn.Close(); err != nil {
|
|
log.Debugf("remote port forwarding: close local connection: %v", err)
|
|
}
|
|
}()
|
|
|
|
nbssh.BidirectionalCopy(log.NewEntry(log.StandardLogger()), localConn, channel)
|
|
}
|
|
|
|
// tcpipForwardMsg is the payload of the "tcpip-forward" and
// "cancel-tcpip-forward" global requests. It is encoded with ssh.Marshal, so
// the field order and types are part of the wire format.
type tcpipForwardMsg struct {
	Host string // address the server is asked to bind
	Port uint32 // port the server is asked to bind
}
|