mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 15:46:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
699
client/ssh/client/client.go
Normal file
699
client/ssh/client/client.go
Normal file
@@ -0,0 +1,699 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/crypto/ssh/knownhosts"
|
||||
"golang.org/x/term"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultDaemonAddr is the default address for the NetBird daemon
|
||||
DefaultDaemonAddr = "unix:///var/run/netbird.sock"
|
||||
// DefaultDaemonAddrWindows is the default address for the NetBird daemon on Windows
|
||||
DefaultDaemonAddrWindows = "tcp://127.0.0.1:41731"
|
||||
)
|
||||
|
||||
// Client wraps crypto/ssh Client for simplified SSH operations
|
||||
type Client struct {
|
||||
client *ssh.Client
|
||||
terminalState *term.State
|
||||
terminalFd int
|
||||
|
||||
windowsStdoutMode uint32 // nolint:unused
|
||||
windowsStdinMode uint32 // nolint:unused
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
log.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
return c.waitForSession(ctx, session)
|
||||
}
|
||||
|
||||
// setupSessionIO connects session streams to local terminal
|
||||
func (c *Client) setupSessionIO(session *ssh.Session) {
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
}
|
||||
|
||||
// waitForSession waits for the session to complete with context cancellation
|
||||
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case err := <-done:
|
||||
return c.handleSessionError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleSessionError processes session termination errors
|
||||
func (c *Client) handleSessionError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return fmt.Errorf("session wait: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreTerminal restores the terminal to its original state
|
||||
func (c *Client) restoreTerminal() {
|
||||
if c.terminalState != nil {
|
||||
_ = term.Restore(c.terminalFd, c.terminalState)
|
||||
c.terminalState = nil
|
||||
c.terminalFd = 0
|
||||
}
|
||||
|
||||
if err := c.restoreWindowsConsoleState(); err != nil {
|
||||
log.Debugf("restore Windows console state: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommand executes a command on the remote host and returns the output
|
||||
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
output, err := session.CombinedOutput(command)
|
||||
if err != nil {
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if !errors.As(err, &e) && !errors.As(err, &em) {
|
||||
return output, fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
|
||||
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
|
||||
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
|
||||
session, cleanup, err := c.createSession(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
if err := c.setupTerminalMode(ctx, session); err != nil {
|
||||
return fmt.Errorf("setup terminal mode: %w", err)
|
||||
}
|
||||
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
}
|
||||
|
||||
defer c.restoreTerminal()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- session.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
select {
|
||||
case <-done:
|
||||
return ctx.Err()
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
return ctx.Err()
|
||||
}
|
||||
case err := <-done:
|
||||
return c.handleCommandError(err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandError processes command execution errors
|
||||
func (c *Client) handleCommandError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var e *ssh.ExitError
|
||||
var em *ssh.ExitMissingError
|
||||
if errors.As(err, &e) || errors.As(err, &em) {
|
||||
return err
|
||||
}
|
||||
|
||||
return fmt.Errorf("execute command: %w", err)
|
||||
}
|
||||
|
||||
// setupContextCancellation sets up context cancellation for a session
|
||||
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = session.Signal(ssh.SIGTERM)
|
||||
_ = session.Close()
|
||||
case <-done:
|
||||
}
|
||||
}()
|
||||
return func() { close(done) }
|
||||
}
|
||||
|
||||
// createSession creates a new SSH session with context cancellation setup
|
||||
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
cancel := c.setupContextCancellation(ctx, session)
|
||||
cleanup := func() {
|
||||
cancel()
|
||||
_ = session.Close()
|
||||
}
|
||||
|
||||
return session, cleanup, nil
|
||||
}
|
||||
|
||||
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||
func getDefaultDaemonAddr() string {
|
||||
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||
return addr
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
type DialOptions struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with specified options
|
||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
daemonAddr := opts.DaemonAddr
|
||||
if daemonAddr == "" {
|
||||
daemonAddr = getDefaultDaemonAddr()
|
||||
}
|
||||
opts.DaemonAddr = daemonAddr
|
||||
|
||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
|
||||
if opts.IdentityFile != "" {
|
||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create SSH key auth: %w", err)
|
||||
}
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
}
|
||||
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
log.Debugf("connection close after handshake failure: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
client := ssh.NewClient(clientConn, chans, reqs)
|
||||
return &Client{
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
return dialSSH(ctx, network, addr, config)
|
||||
}
|
||||
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
|
||||
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||
return dialSSH(ctx, network, addr, configWithJWT)
|
||||
}
|
||||
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
hint := profilemanager.GetLoginHint()
|
||||
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
|
||||
// User's known_hosts file (highest priority)
|
||||
if homeDir, err := os.UserHomeDir(); err == nil {
|
||||
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
|
||||
files = append(files, userKnownHosts)
|
||||
}
|
||||
|
||||
// NetBird managed known_hosts files
|
||||
if runtime.GOOS == "windows" {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
|
||||
files = append(files, netbirdKnownHosts)
|
||||
} else {
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
|
||||
files = append(files, "/etc/ssh/ssh_known_hosts")
|
||||
}
|
||||
|
||||
return files
|
||||
}
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback
|
||||
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
if opts.InsecureSkipVerify {
|
||||
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
if daemonAddr == "" {
|
||||
return fmt.Errorf("no daemon address")
|
||||
}
|
||||
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
||||
}
|
||||
|
||||
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
||||
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
||||
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
||||
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("host key verification failed: key for %s not found in any known_hosts file", hostname)
|
||||
}
|
||||
|
||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||
if knownHostsFile != "" {
|
||||
return []string{knownHostsFile}
|
||||
}
|
||||
return getKnownHostsFiles()
|
||||
}
|
||||
|
||||
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
return hostKeyCallbacks
|
||||
}
|
||||
|
||||
// createSSHKeyAuth creates SSH key authentication from a private key file
|
||||
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse SSH private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.PublicKeys(signer), nil
|
||||
}
|
||||
|
||||
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
|
||||
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
|
||||
localListener, err := net.Listen("tcp", localAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen on %s: %w", localAddr, err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
}()
|
||||
for {
|
||||
localConn, err := localListener.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleLocalForward(localConn, remoteAddr)
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
if err := localListener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
||||
log.Debugf("local listener close error: %v", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// handleLocalForward handles a single local port forwarding connection
|
||||
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
channel, err := c.client.Dial("tcp", remoteAddr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "administratively prohibited") {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
|
||||
} else {
|
||||
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("local forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("local forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
|
||||
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
|
||||
host, port, err := c.parseRemoteAddress(remoteAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
req := c.buildTCPIPForwardRequest(host, port)
|
||||
if err := c.sendTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("setup remote forward: %w", err)
|
||||
}
|
||||
|
||||
go c.handleRemoteForwardChannels(ctx, localAddr)
|
||||
|
||||
<-ctx.Done()
|
||||
|
||||
if err := c.cancelTCPIPForwardRequest(req); err != nil {
|
||||
return fmt.Errorf("cancel tcpip-forward: %w", err)
|
||||
}
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// parseRemoteAddress parses host and port from remote address string
|
||||
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
|
||||
host, portStr, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
return host, uint32(port), nil
|
||||
}
|
||||
|
||||
// buildTCPIPForwardRequest creates a tcpip-forward request message
|
||||
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
|
||||
return tcpipForwardMsg{
|
||||
Host: host,
|
||||
Port: port,
|
||||
}
|
||||
}
|
||||
|
||||
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
|
||||
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send tcpip-forward request: %w", err)
|
||||
}
|
||||
if !ok {
|
||||
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cancelTCPIPForwardRequest cancels the tcpip-forward request
|
||||
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
|
||||
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
|
||||
if err != nil {
|
||||
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
|
||||
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
|
||||
// Get the channel once - subsequent calls return nil!
|
||||
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
|
||||
if channelRequests == nil {
|
||||
log.Debugf("forwarded-tcpip channel type already being handled")
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case newChan := <-channelRequests:
|
||||
if newChan != nil {
|
||||
go c.handleRemoteForwardChannel(newChan, localAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
|
||||
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
|
||||
channel, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := channel.Close(); err != nil {
|
||||
log.Debugf("remote channel close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
localConn, err := net.Dial("tcp", localAddr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := localConn.Close(); err != nil {
|
||||
log.Debugf("local connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(localConn, channel); err != nil {
|
||||
log.Debugf("remote forward copy error (remote->local): %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(channel, localConn); err != nil {
|
||||
log.Debugf("remote forward copy error (local->remote): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
512
client/ssh/client/client_test.go
Normal file
512
client/ssh/client/client_test.go
Normal file
@@ -0,0 +1,512 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and start server
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Test Dial
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Verify client is connected
|
||||
assert.NotNil(t, client.client)
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
t.Run("ExecuteCommand captures output", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo hello")
|
||||
assert.NoError(t, err)
|
||||
assert.Contains(t, string(output), "hello")
|
||||
})
|
||||
|
||||
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
|
||||
err := client.ExecuteCommandWithIO(ctx, "echo world")
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("commands with flags work", func(t *testing.T) {
|
||||
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
|
||||
})
|
||||
|
||||
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
|
||||
var testCmd string
|
||||
if runtime.GOOS == "windows" {
|
||||
testCmd = "echo hello | Select-String notfound"
|
||||
} else {
|
||||
testCmd = "echo 'hello' | grep 'notfound'"
|
||||
}
|
||||
_, err := client.ExecuteCommand(ctx, testCmd)
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate client key for multiple connections
|
||||
|
||||
const numClients = 3
|
||||
clients := make([]*Client, numClients)
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
for i := 0; i < numClients; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
cancel()
|
||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||
clients[i] = client
|
||||
}
|
||||
|
||||
for i, client := range clients {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err, "Client %d should close without error", i)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
server, serverAddr, _ := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "timeout"),
|
||||
"Expected timeout-related error, got: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cmdCancel()
|
||||
|
||||
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
|
||||
if err != nil {
|
||||
var exitMissingErr *cryptossh.ExitMissingError
|
||||
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
errors.As(err, &exitMissingErr)
|
||||
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if client != nil {
|
||||
require.NoError(t, client.Close(), "Client should close without error")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_TerminalState(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
assert.Nil(t, client.terminalState)
|
||||
assert.Equal(t, 0, client.terminalFd)
|
||||
|
||||
client.restoreTerminal()
|
||||
assert.Nil(t, client.terminalState)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.OpenTerminal(ctx)
|
||||
// In test environment without a real terminal, this may complete quickly or timeout
|
||||
// Both behaviors are acceptable for testing terminal state management
|
||||
if err != nil {
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "console"),
|
||||
"Should timeout or have console error on Windows")
|
||||
} else {
|
||||
// On Unix systems in test environment, we may get various errors
|
||||
// including timeouts or terminal-related errors
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "context deadline exceeded") ||
|
||||
strings.Contains(err.Error(), "terminal") ||
|
||||
strings.Contains(err.Error(), "pty"),
|
||||
"Expected timeout or terminal-related error, got: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return server, serverAddr, client
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwarding(t *testing.T) {
|
||||
server, _, client := setupTestSSHServerAndClient(t)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
t.Run("local forwarding times out gracefully", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
errors.Is(err, context.DeadlineExceeded) ||
|
||||
errors.Is(err, context.Canceled) ||
|
||||
strings.Contains(err.Error(), "connection"),
|
||||
"Expected context or connection error")
|
||||
})
|
||||
|
||||
t.Run("remote forwarding denied", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
assert.True(t,
|
||||
strings.Contains(err.Error(), "denied") ||
|
||||
strings.Contains(err.Error(), "disabled"),
|
||||
"Should be denied by default")
|
||||
})
|
||||
|
||||
t.Run("invalid addresses fail", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
|
||||
assert.Error(t, err)
|
||||
|
||||
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping data transfer test in short mode")
|
||||
}
|
||||
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// Port forwarding requires the actual current user, not test user
|
||||
realUser, err := getRealCurrentUser()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip if running as system account that can't do port forwarding
|
||||
if testutil.IsSystemAccount(realUser) {
|
||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||
}
|
||||
|
||||
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||
InsecureSkipVerify: true, // Skip host key verification for test
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := testServer.Close(); err != nil {
|
||||
t.Logf("test server close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "Hello, World!"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("connection read error: %v", err)
|
||||
return
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("connection write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
if err := localListener.Close(); err != nil {
|
||||
t.Logf("local listener close error: %v", err)
|
||||
}
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
|
||||
if err != nil && !errors.Is(err, context.Canceled) {
|
||||
if isWindowsPrivilegeError(err) {
|
||||
t.Logf("Port forward failed due to Windows privilege restrictions: %v", err)
|
||||
} else {
|
||||
t.Logf("Port forward error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
_, err = conn.Write([]byte("test"))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
|
||||
t.Logf("set read deadline error: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(expectedResponse), n)
|
||||
assert.Equal(t, expectedResponse, string(response))
|
||||
}
|
||||
|
||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||
func getRealCurrentUser() (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
}
|
||||
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("unable to determine current user")
|
||||
}
|
||||
|
||||
// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions
|
||||
func isWindowsPrivilegeError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := strings.ToLower(err.Error())
|
||||
return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD
|
||||
strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess)
|
||||
strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser)
|
||||
strings.Contains(errStr, "privilege") ||
|
||||
strings.Contains(errStr, "access denied") ||
|
||||
strings.Contains(errStr, "user authentication failed")
|
||||
}
|
||||
127
client/ssh/client/terminal_unix.go
Normal file
127
client/ssh/client/terminal_unix.go
Normal file
@@ -0,0 +1,127 @@
|
||||
//go:build !windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
|
||||
stdinFd := int(os.Stdin.Fd())
|
||||
|
||||
if !term.IsTerminal(stdinFd) {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
fd := int(os.Stdin.Fd())
|
||||
|
||||
state, err := term.MakeRaw(fd)
|
||||
if err != nil {
|
||||
return c.setupNonTerminalMode(ctx, session)
|
||||
}
|
||||
|
||||
if err := c.setupTerminal(session, fd); err != nil {
|
||||
if restoreErr := term.Restore(fd, state); restoreErr != nil {
|
||||
log.Debugf("restore terminal state: %v", restoreErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c.terminalState = state
|
||||
c.terminalFd = fd
|
||||
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
|
||||
|
||||
go func() {
|
||||
defer signal.Stop(sigChan)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
case sig := <-sigChan:
|
||||
if err := term.Restore(fd, state); err != nil {
|
||||
log.Debugf("restore terminal state: %v", err)
|
||||
}
|
||||
signal.Reset(sig)
|
||||
s, ok := sig.(syscall.Signal)
|
||||
if !ok {
|
||||
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
|
||||
return
|
||||
}
|
||||
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
|
||||
log.Debugf("kill process with signal %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// restoreWindowsConsoleState is a no-op on Unix systems
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
|
||||
w, h, err := term.GetSize(fd)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get terminal size: %w", err)
|
||||
}
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
// Ctrl+C
|
||||
ssh.VINTR: 3,
|
||||
// Ctrl+\
|
||||
ssh.VQUIT: 28,
|
||||
// Backspace
|
||||
ssh.VERASE: 127,
|
||||
// Ctrl+U
|
||||
ssh.VKILL: 21,
|
||||
// Ctrl+D
|
||||
ssh.VEOF: 4,
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
// Ctrl+Q
|
||||
ssh.VSTART: 17,
|
||||
// Ctrl+S
|
||||
ssh.VSTOP: 19,
|
||||
// Ctrl+Z
|
||||
ssh.VSUSP: 26,
|
||||
// Ctrl+O
|
||||
ssh.VDISCARD: 15,
|
||||
// Ctrl+R
|
||||
ssh.VREPRINT: 18,
|
||||
// Ctrl+W
|
||||
ssh.VWERASE: 23,
|
||||
// Ctrl+V
|
||||
ssh.VLNEXT: 22,
|
||||
}
|
||||
|
||||
terminal := os.Getenv("TERM")
|
||||
if terminal == "" {
|
||||
terminal = "xterm-256color"
|
||||
}
|
||||
|
||||
if err := session.RequestPty(terminal, h, w, modes); err != nil {
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
265
client/ssh/client/terminal_windows.go
Normal file
265
client/ssh/client/terminal_windows.go
Normal file
@@ -0,0 +1,265 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
|
||||
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
// ConsoleUnavailableError indicates that Windows console handles are not available
|
||||
// (e.g., in CI environments where stdout/stdin are redirected)
|
||||
type ConsoleUnavailableError struct {
|
||||
Operation string
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Error() string {
|
||||
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
|
||||
}
|
||||
|
||||
func (e *ConsoleUnavailableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type coord struct {
|
||||
x, y int16
|
||||
}
|
||||
|
||||
type smallRect struct {
|
||||
left, top, right, bottom int16
|
||||
}
|
||||
|
||||
type consoleScreenBufferInfo struct {
|
||||
size coord
|
||||
cursorPosition coord
|
||||
attributes uint16
|
||||
window smallRect
|
||||
maximumWindowSize coord
|
||||
}
|
||||
|
||||
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
|
||||
if err := c.saveWindowsConsoleState(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("console unavailable, not requesting PTY: %v", err)
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("save console state: %w", err)
|
||||
}
|
||||
|
||||
if err := c.enableWindowsVirtualTerminal(); err != nil {
|
||||
var consoleErr *ConsoleUnavailableError
|
||||
if errors.As(err, &consoleErr) {
|
||||
log.Debugf("virtual terminal unavailable: %v", err)
|
||||
} else {
|
||||
return fmt.Errorf("failed to enable virtual terminal: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
w, h := c.getWindowsConsoleSize()
|
||||
|
||||
modes := ssh.TerminalModes{
|
||||
ssh.ECHO: 1,
|
||||
ssh.TTY_OP_ISPEED: 14400,
|
||||
ssh.TTY_OP_OSPEED: 14400,
|
||||
ssh.ICRNL: 1,
|
||||
ssh.OPOST: 1,
|
||||
ssh.ONLCR: 1,
|
||||
ssh.ISIG: 1,
|
||||
ssh.ICANON: 1,
|
||||
ssh.VINTR: 3, // Ctrl+C
|
||||
ssh.VQUIT: 28, // Ctrl+\
|
||||
ssh.VERASE: 127, // Backspace
|
||||
ssh.VKILL: 21, // Ctrl+U
|
||||
ssh.VEOF: 4, // Ctrl+D
|
||||
ssh.VEOL: 0,
|
||||
ssh.VEOL2: 0,
|
||||
ssh.VSTART: 17, // Ctrl+Q
|
||||
ssh.VSTOP: 19, // Ctrl+S
|
||||
ssh.VSUSP: 26, // Ctrl+Z
|
||||
ssh.VDISCARD: 15, // Ctrl+O
|
||||
ssh.VWERASE: 23, // Ctrl+W
|
||||
ssh.VLNEXT: 22, // Ctrl+V
|
||||
ssh.VREPRINT: 18, // Ctrl+R
|
||||
}
|
||||
|
||||
if err := session.RequestPty("xterm-256color", h, w, modes); err != nil {
|
||||
if restoreErr := c.restoreWindowsConsoleState(); restoreErr != nil {
|
||||
log.Debugf("restore Windows console state: %v", restoreErr)
|
||||
}
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) saveWindowsConsoleState() error {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in saveWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
var stdoutMode, stdinMode uint32
|
||||
|
||||
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdout console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get stdin console mode: %v", err)
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode",
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 1
|
||||
c.windowsStdoutMode = stdoutMode
|
||||
c.windowsStdinMode = stdinMode
|
||||
|
||||
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) enableWindowsVirtualTerminal() (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
var mode uint32
|
||||
|
||||
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdout console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode |= enableVirtualTerminalProcessing
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "enable virtual terminal processing",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "get stdin console mode for VT",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
|
||||
mode |= enableVirtualTerminalInput
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
|
||||
if ret == 0 {
|
||||
return &ConsoleUnavailableError{
|
||||
Operation: "set stdin raw mode",
|
||||
Err: winErr,
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("enabled Windows virtual terminal processing")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) getWindowsConsoleSize() (int, int) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Debugf("panic in getWindowsConsoleSize: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
var csbi consoleScreenBufferInfo
|
||||
|
||||
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to get console buffer info, using defaults: %v", err)
|
||||
return 80, 24
|
||||
}
|
||||
|
||||
width := int(csbi.window.right - csbi.window.left + 1)
|
||||
height := int(csbi.window.bottom - csbi.window.top + 1)
|
||||
|
||||
log.Debugf("Windows console size: %dx%d", width, height)
|
||||
return width, height
|
||||
}
|
||||
|
||||
func (c *Client) restoreWindowsConsoleState() error {
|
||||
var err error
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
if c.terminalFd != 1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
stdout := syscall.Handle(os.Stdout.Fd())
|
||||
stdin := syscall.Handle(os.Stdin.Fd())
|
||||
|
||||
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdout console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdout console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
|
||||
if ret == 0 {
|
||||
log.Debugf("failed to restore stdin console mode: %v", winErr)
|
||||
if err == nil {
|
||||
err = fmt.Errorf("restore stdin console mode: %w", winErr)
|
||||
}
|
||||
}
|
||||
|
||||
c.terminalFd = 0
|
||||
c.windowsStdoutMode = 0
|
||||
c.windowsStdinMode = 0
|
||||
|
||||
log.Debugf("restored Windows console state")
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user