Complete overhaul

This commit is contained in:
Viktor Liu
2025-06-24 12:19:53 +02:00
parent f56075ca15
commit 9d1554f9f7
74 changed files with 16626 additions and 4524 deletions

712
client/ssh/client/client.go Normal file
View File

@@ -0,0 +1,712 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/knownhosts"
"golang.org/x/term"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
)
// Client wraps crypto/ssh Client for simplified SSH operations
type Client struct {
client *ssh.Client
terminalState *term.State
terminalFd int
windowsStdoutMode uint32 // nolint:unused
windowsStdinMode uint32 // nolint:unused
}
// Close terminates the SSH connection
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal opens an interactive terminal session
func (c *Client) OpenTerminal(ctx context.Context) error {
session, err := c.client.NewSession()
if err != nil {
return fmt.Errorf("new session: %w", err)
}
defer func() {
if err := session.Close(); err != nil {
log.Debugf("session close error: %v", err)
}
}()
if err := c.setupTerminalMode(ctx, session); err != nil {
return err
}
c.setupSessionIO(ctx, session)
if err := session.Shell(); err != nil {
return fmt.Errorf("start shell: %w", err)
}
return c.waitForSession(ctx, session)
}
// setupSessionIO connects session streams to local terminal
func (c *Client) setupSessionIO(ctx context.Context, session *ssh.Session) {
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
}
// waitForSession waits for the session to complete with context cancellation
func (c *Client) waitForSession(ctx context.Context, session *ssh.Session) error {
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
defer c.restoreTerminal()
select {
case <-ctx.Done():
return ctx.Err()
case err := <-done:
return c.handleSessionError(err)
}
}
// handleSessionError processes session termination errors
func (c *Client) handleSessionError(err error) error {
if err == nil {
return nil
}
var e *ssh.ExitError
var em *ssh.ExitMissingError
if !errors.As(err, &e) && !errors.As(err, &em) {
return fmt.Errorf("session wait: %w", err)
}
return nil
}
// restoreTerminal restores the terminal to its original state
func (c *Client) restoreTerminal() {
if c.terminalState != nil {
_ = term.Restore(c.terminalFd, c.terminalState)
c.terminalState = nil
c.terminalFd = 0
}
if err := c.restoreWindowsConsoleState(); err != nil {
log.Debugf("restore Windows console state: %v", err)
}
}
// ExecuteCommand executes a command on the remote host and returns the output
func (c *Client) ExecuteCommand(ctx context.Context, command string) ([]byte, error) {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return nil, err
}
defer cleanup()
output, err := session.CombinedOutput(command)
if err != nil {
var e *ssh.ExitError
var em *ssh.ExitMissingError
if !errors.As(err, &e) && !errors.As(err, &em) {
return output, fmt.Errorf("execute command: %w", err)
}
}
return output, nil
}
// ExecuteCommandWithIO executes a command with interactive I/O connected to local terminal
func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return fmt.Errorf("create session: %w", err)
}
defer cleanup()
c.setupSessionIO(ctx, session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)
}
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
select {
case <-done:
return ctx.Err()
case <-time.After(100 * time.Millisecond):
return ctx.Err()
}
case err := <-done:
return c.handleCommandError(err)
}
}
// ExecuteCommandWithPTY executes a command with a pseudo-terminal for interactive sessions
func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) error {
session, cleanup, err := c.createSession(ctx)
if err != nil {
return err
}
defer cleanup()
if err := c.setupTerminalMode(ctx, session); err != nil {
return fmt.Errorf("setup terminal mode: %w", err)
}
c.setupSessionIO(ctx, session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)
}
defer c.restoreTerminal()
done := make(chan error, 1)
go func() {
done <- session.Wait()
}()
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
select {
case <-done:
return ctx.Err()
case <-time.After(100 * time.Millisecond):
return ctx.Err()
}
case err := <-done:
return c.handleCommandError(err)
}
}
// handleCommandError processes command execution errors, treating exit codes as normal
func (c *Client) handleCommandError(err error) error {
if err == nil {
return nil
}
var e *ssh.ExitError
var em *ssh.ExitMissingError
if !errors.As(err, &e) && !errors.As(err, &em) {
return fmt.Errorf("execute command: %w", err)
}
return nil
}
// setupContextCancellation sets up context cancellation for a session
func (c *Client) setupContextCancellation(ctx context.Context, session *ssh.Session) func() {
done := make(chan struct{})
go func() {
select {
case <-ctx.Done():
_ = session.Signal(ssh.SIGTERM)
_ = session.Close()
case <-done:
}
}()
return func() { close(done) }
}
// createSession creates a new SSH session with context cancellation setup
func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error) {
session, err := c.client.NewSession()
if err != nil {
return nil, nil, fmt.Errorf("new session: %w", err)
}
cancel := c.setupContextCancellation(ctx, session)
cleanup := func() {
cancel()
_ = session.Close()
}
return session, cleanup, nil
}
// Dial connects to the given ssh server with proper host key verification
func Dial(ctx context.Context, addr, user string) (*Client, error) {
hostKeyCallback, err := createHostKeyCallback(addr)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
}
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
}
return dial(ctx, "tcp", addr, config)
}
// DialInsecure connects to the given ssh server without host key verification (for testing only)
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}
return dial(ctx, "tcp", addr, config)
}
// DialOptions contains options for SSH connections
type DialOptions struct {
KnownHostsFile string
IdentityFile string
DaemonAddr string
}
// DialWithOptions connects to the given ssh server with specified options
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
}
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
}
// Add SSH key authentication if identity file is specified
if opts.IdentityFile != "" {
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
if err != nil {
return nil, fmt.Errorf("create SSH key auth: %w", err)
}
config.Auth = append(config.Auth, authMethod)
}
return dial(ctx, "tcp", addr, config)
}
// dial establishes an SSH connection
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
return nil, fmt.Errorf("dial %s: %w", addr, err)
}
clientConn, chans, reqs, err := ssh.NewClientConn(conn, addr, config)
if err != nil {
if closeErr := conn.Close(); closeErr != nil {
log.Debugf("connection close after handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("ssh handshake: %w", err)
}
client := ssh.NewClient(clientConn, chans, reqs)
return &Client{
client: client,
}, nil
}
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
return createHostKeyCallbackWithDaemonAddr(addr, "unix:///var/run/netbird.sock")
}
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// First try to get host key from NetBird daemon
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil {
return nil
}
// Fallback to known_hosts files
knownHostsFiles := getKnownHostsFiles()
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}, nil
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
// Connect to NetBird daemon using the same logic as CLI
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
conn, err := grpc.DialContext(
ctx,
strings.TrimPrefix(daemonAddr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
return fmt.Errorf("failed to connect to NetBird daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
client := proto.NewDaemonServiceClient(conn)
// Try both hostname and IP address from remote.String()
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
for _, addr := range addresses {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
cancel()
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
continue
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
continue
}
// Parse the stored SSH host key
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(response.GetSshHostKey())
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
continue
}
// Compare the keys
if key.Type() == storedKey.Type() && string(key.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
return nil
} else {
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), key.Type())
}
}
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string {
var files []string
// User's known_hosts file (highest priority)
if homeDir, err := os.UserHomeDir(); err == nil {
userKnownHosts := filepath.Join(homeDir, ".ssh", "known_hosts")
files = append(files, userKnownHosts)
}
// NetBird managed known_hosts files
if runtime.GOOS == "windows" {
programData := os.Getenv("PROGRAMDATA")
if programData == "" {
programData = `C:\ProgramData`
}
netbirdKnownHosts := filepath.Join(programData, "ssh", "ssh_known_hosts.d", "99-netbird")
files = append(files, netbirdKnownHosts)
} else {
files = append(files, "/etc/ssh/ssh_known_hosts.d/99-netbird")
files = append(files, "/etc/ssh/ssh_known_hosts")
}
return files
}
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// First try to get host key from NetBird daemon (if daemon address provided)
if opts.DaemonAddr != "" {
if err := verifyHostKeyViaDaemon(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil
}
}
// Fallback to known_hosts files
var knownHostsFiles []string
if opts.KnownHostsFile != "" {
knownHostsFiles = append(knownHostsFiles, opts.KnownHostsFile)
} else {
knownHostsFiles = getKnownHostsFiles()
}
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}, nil
}
// createSSHKeyAuth creates SSH key authentication from a private key file
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
keyData, err := os.ReadFile(keyFile)
if err != nil {
return nil, fmt.Errorf("read SSH key file %s: %w", keyFile, err)
}
signer, err := ssh.ParsePrivateKey(keyData)
if err != nil {
return nil, fmt.Errorf("parse SSH private key: %w", err)
}
return ssh.PublicKeys(signer), nil
}
// LocalPortForward sets up local port forwarding, binding to localAddr and forwarding to remoteAddr
func (c *Client) LocalPortForward(ctx context.Context, localAddr, remoteAddr string) error {
localListener, err := net.Listen("tcp", localAddr)
if err != nil {
return fmt.Errorf("listen on %s: %w", localAddr, err)
}
go func() {
defer func() {
if err := localListener.Close(); err != nil {
log.Debugf("local listener close error: %v", err)
}
}()
for {
localConn, err := localListener.Accept()
if err != nil {
if ctx.Err() != nil {
return
}
continue
}
go c.handleLocalForward(localConn, remoteAddr)
}
}()
<-ctx.Done()
return ctx.Err()
}
// handleLocalForward handles a single local port forwarding connection
func (c *Client) handleLocalForward(localConn net.Conn, remoteAddr string) {
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
}
}()
channel, err := c.client.Dial("tcp", remoteAddr)
if err != nil {
if strings.Contains(err.Error(), "administratively prohibited") {
_, _ = fmt.Fprintf(os.Stderr, "channel open failed: administratively prohibited: port forwarding is disabled\n")
} else {
log.Debugf("local port forwarding to %s failed: %v", remoteAddr, err)
}
return
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
}
}()
go func() {
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("local forward copy error (local->remote): %v", err)
}
}()
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("local forward copy error (remote->local): %v", err)
}
}
// RemotePortForward sets up remote port forwarding, binding on remote and forwarding to localAddr
func (c *Client) RemotePortForward(ctx context.Context, remoteAddr, localAddr string) error {
host, port, err := c.parseRemoteAddress(remoteAddr)
if err != nil {
return err
}
req := c.buildTCPIPForwardRequest(host, port)
if err := c.sendTCPIPForwardRequest(req); err != nil {
return err
}
go c.handleRemoteForwardChannels(ctx, localAddr)
<-ctx.Done()
if err := c.cancelTCPIPForwardRequest(req); err != nil {
return fmt.Errorf("cancel tcpip-forward: %w", err)
}
return ctx.Err()
}
// parseRemoteAddress parses host and port from remote address string
func (c *Client) parseRemoteAddress(remoteAddr string) (string, uint32, error) {
host, portStr, err := net.SplitHostPort(remoteAddr)
if err != nil {
return "", 0, fmt.Errorf("parse remote address %s: %w", remoteAddr, err)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return "", 0, fmt.Errorf("parse remote port %s: %w", portStr, err)
}
return host, uint32(port), nil
}
// buildTCPIPForwardRequest creates a tcpip-forward request message
func (c *Client) buildTCPIPForwardRequest(host string, port uint32) tcpipForwardMsg {
return tcpipForwardMsg{
Host: host,
Port: port,
}
}
// sendTCPIPForwardRequest sends the tcpip-forward request to establish remote port forwarding
func (c *Client) sendTCPIPForwardRequest(req tcpipForwardMsg) error {
ok, _, err := c.client.SendRequest("tcpip-forward", true, ssh.Marshal(&req))
if err != nil {
return fmt.Errorf("send tcpip-forward request: %w", err)
}
if !ok {
return fmt.Errorf("remote port forwarding denied by server (check if --allow-ssh-remote-port-forwarding is enabled)")
}
return nil
}
// cancelTCPIPForwardRequest cancels the tcpip-forward request
func (c *Client) cancelTCPIPForwardRequest(req tcpipForwardMsg) error {
_, _, err := c.client.SendRequest("cancel-tcpip-forward", true, ssh.Marshal(&req))
if err != nil {
return fmt.Errorf("send cancel-tcpip-forward request: %w", err)
}
return nil
}
// handleRemoteForwardChannels handles incoming forwarded-tcpip channels
func (c *Client) handleRemoteForwardChannels(ctx context.Context, localAddr string) {
// Get the channel once - subsequent calls return nil!
channelRequests := c.client.HandleChannelOpen("forwarded-tcpip")
if channelRequests == nil {
log.Debugf("forwarded-tcpip channel type already being handled")
return
}
for {
select {
case <-ctx.Done():
return
case newChan := <-channelRequests:
if newChan != nil {
go c.handleRemoteForwardChannel(newChan, localAddr)
}
}
}
}
// handleRemoteForwardChannel handles a single forwarded-tcpip channel
func (c *Client) handleRemoteForwardChannel(newChan ssh.NewChannel, localAddr string) {
channel, reqs, err := newChan.Accept()
if err != nil {
return
}
defer func() {
if err := channel.Close(); err != nil {
log.Debugf("remote channel close error: %v", err)
}
}()
go ssh.DiscardRequests(reqs)
localConn, err := net.Dial("tcp", localAddr)
if err != nil {
return
}
defer func() {
if err := localConn.Close(); err != nil {
log.Debugf("local connection close error: %v", err)
}
}()
go func() {
if _, err := io.Copy(localConn, channel); err != nil {
log.Debugf("remote forward copy error (remote->local): %v", err)
}
}()
if _, err := io.Copy(channel, localConn); err != nil {
log.Debugf("remote forward copy error (local->remote): %v", err)
}
}
// tcpipForwardMsg represents the structure for tcpip-forward requests
type tcpipForwardMsg struct {
Host string
Port uint32
}

View File

@@ -0,0 +1,468 @@
package client
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
)
func TestSSHClient_DialWithKey(t *testing.T) {
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create and start server
server := sshserver.New(hostKey)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Test Dial
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
require.NoError(t, err)
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
// Verify client is connected
assert.NotNil(t, client.client)
}
func TestSSHClient_CommandExecution(t *testing.T) {
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
t.Run("ExecuteCommand captures output", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo hello")
assert.NoError(t, err)
assert.Contains(t, string(output), "hello")
})
t.Run("ExecuteCommandWithIO streams output", func(t *testing.T) {
err := client.ExecuteCommandWithIO(ctx, "echo world")
assert.NoError(t, err)
})
t.Run("commands with flags work", func(t *testing.T) {
output, err := client.ExecuteCommand(ctx, "echo -n test_flag")
assert.NoError(t, err)
assert.Equal(t, "test_flag", strings.TrimSpace(string(output)))
})
t.Run("non-zero exit codes don't return errors", func(t *testing.T) {
var testCmd string
if runtime.GOOS == "windows" {
testCmd = "echo hello | Select-String notfound"
} else {
testCmd = "echo 'hello' | grep 'notfound'"
}
_, err := client.ExecuteCommand(ctx, testCmd)
assert.NoError(t, err)
})
}
func TestSSHClient_ConnectionHandling(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Generate client key for multiple connections
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
require.NoError(t, err)
const numClients = 3
clients := make([]*Client, numClients)
for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i))
cancel()
require.NoError(t, err, "Client %d should connect successfully", i)
clients[i] = client
}
for i, client := range clients {
err := client.Close()
assert.NoError(t, err, "Client %d should close without error", i)
}
}
func TestSSHClient_ContextCancellation(t *testing.T) {
server, serverAddr, _ := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
require.NoError(t, err)
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := getCurrentUsername()
_, err = DialInsecure(ctx, serverAddr, currentUser)
if err != nil {
assert.Contains(t, err.Error(), "context")
}
})
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
cmdCtx, cmdCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cmdCancel()
err = client.ExecuteCommandWithPTY(cmdCtx, "sleep 10")
if err != nil {
var exitMissingErr *cryptossh.ExitMissingError
isValidCancellation := errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
errors.As(err, &exitMissingErr)
assert.True(t, isValidCancellation, "Should handle command cancellation properly")
}
})
}
func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
server := sshserver.New(hostKey)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
client, err := DialInsecure(ctx, serverAddr, currentUser)
assert.NoError(t, err)
if client != nil {
require.NoError(t, client.Close(), "Client should close without error")
}
})
}
func TestSSHClient_TerminalState(t *testing.T) {
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
assert.Nil(t, client.terminalState)
assert.Equal(t, 0, client.terminalFd)
client.restoreTerminal()
assert.Nil(t, client.terminalState)
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
err := client.OpenTerminal(ctx)
// In test environment without a real terminal, this may complete quickly or timeout
// Both behaviors are acceptable for testing terminal state management
if err != nil {
if runtime.GOOS == "windows" {
assert.True(t,
strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "console"),
"Should timeout or have console error on Windows")
} else {
// On Unix systems in test environment, we may get various errors
// including timeouts or terminal-related errors
assert.True(t,
strings.Contains(err.Error(), "context deadline exceeded") ||
strings.Contains(err.Error(), "terminal") ||
strings.Contains(err.Error(), "pty"),
"Expected timeout or terminal-related error, got: %v", err)
}
}
}
func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Client) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
require.NoError(t, err)
return server, serverAddr, client
}
func TestSSHClient_PortForwarding(t *testing.T) {
server, _, client := setupTestSSHServerAndClient(t)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
defer func() {
err := client.Close()
assert.NoError(t, err)
}()
t.Run("local forwarding times out gracefully", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
err := client.LocalPortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
assert.Error(t, err)
assert.True(t,
errors.Is(err, context.DeadlineExceeded) ||
errors.Is(err, context.Canceled) ||
strings.Contains(err.Error(), "connection"),
"Expected context or connection error")
})
t.Run("remote forwarding denied", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := client.RemotePortForward(ctx, "127.0.0.1:0", "127.0.0.1:8080")
assert.Error(t, err)
assert.True(t,
strings.Contains(err.Error(), "denied") ||
strings.Contains(err.Error(), "disabled"),
"Should be denied by default")
})
t.Run("invalid addresses fail", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
err := client.LocalPortForward(ctx, "invalid:address", "127.0.0.1:8080")
assert.Error(t, err)
err = client.LocalPortForward(ctx, "127.0.0.1:0", "invalid:address")
assert.Error(t, err)
})
}
func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
if testing.Short() {
t.Skip("Skipping data transfer test in short mode")
}
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
server.SetAllowLocalPortForwarding(true)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
t.Logf("client close error: %v", err)
}
}()
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer func() {
if err := testServer.Close(); err != nil {
t.Logf("test server close error: %v", err)
}
}()
testServerAddr := testServer.Addr().String()
expectedResponse := "Hello, World!"
go func() {
for {
conn, err := testServer.Accept()
if err != nil {
return
}
go func(c net.Conn) {
defer func() {
if err := c.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
buf := make([]byte, 1024)
if _, err := c.Read(buf); err != nil {
t.Logf("connection read error: %v", err)
return
}
if _, err := c.Write([]byte(expectedResponse)); err != nil {
t.Logf("connection write error: %v", err)
}
}(conn)
}
}()
localListener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
localAddr := localListener.Addr().String()
if err := localListener.Close(); err != nil {
t.Logf("local listener close error: %v", err)
}
ctx, cancel = context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
go func() {
err := client.LocalPortForward(ctx, localAddr, testServerAddr)
if err != nil && !errors.Is(err, context.Canceled) {
t.Logf("Port forward error: %v", err)
}
}()
time.Sleep(100 * time.Millisecond)
conn, err := net.DialTimeout("tcp", localAddr, 2*time.Second)
require.NoError(t, err)
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
_, err = conn.Write([]byte("test"))
require.NoError(t, err)
if err := conn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
t.Logf("set read deadline error: %v", err)
}
response := make([]byte, len(expectedResponse))
n, err := io.ReadFull(conn, response)
require.NoError(t, err)
assert.Equal(t, len(expectedResponse), n)
assert.Equal(t, expectedResponse, string(response))
}
// getCurrentUsername returns the current username for SSH connections
func getCurrentUsername() string {
if runtime.GOOS == "windows" {
if currentUser, err := user.Current(); err == nil {
username := currentUser.Username
if idx := strings.LastIndex(username, "\\"); idx != -1 {
username = username[idx+1:]
}
return strings.ToLower(username)
}
}
if username := os.Getenv("USER"); username != "" {
return username
}
if currentUser, err := user.Current(); err == nil {
return currentUser.Username
}
return "test-user"
}

View File

@@ -0,0 +1,135 @@
//go:build !windows
package client
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"golang.org/x/term"
)
func (c *Client) setupTerminalMode(ctx context.Context, session *ssh.Session) error {
fd := int(os.Stdout.Fd())
if !term.IsTerminal(fd) {
return c.setupNonTerminalMode(ctx, session)
}
state, err := term.MakeRaw(fd)
if err != nil {
return c.setupNonTerminalMode(ctx, session)
}
c.terminalState = state
c.terminalFd = fd
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT)
go func() {
defer signal.Stop(sigChan)
select {
case <-ctx.Done():
if err := term.Restore(fd, state); err != nil {
log.Debugf("restore terminal state: %v", err)
}
case sig := <-sigChan:
if err := term.Restore(fd, state); err != nil {
log.Debugf("restore terminal state: %v", err)
}
signal.Reset(sig)
s, ok := sig.(syscall.Signal)
if !ok {
log.Debugf("signal %v is not a syscall.Signal: %T", sig, sig)
return
}
if err := syscall.Kill(syscall.Getpid(), s); err != nil {
log.Debugf("kill process with signal %v: %v", s, err)
}
}
}()
return c.setupTerminal(session, fd)
}
func (c *Client) setupNonTerminalMode(_ context.Context, session *ssh.Session) error {
w, h := 80, 24
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("request pty: %w", err)
}
return nil
}
// restoreWindowsConsoleState is a no-op on Unix systems
func (c *Client) restoreWindowsConsoleState() error {
return nil
}
func (c *Client) setupTerminal(session *ssh.Session, fd int) error {
w, h, err := term.GetSize(fd)
if err != nil {
return fmt.Errorf("get terminal size: %w", err)
}
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
// Ctrl+C
ssh.VINTR: 3,
// Ctrl+\
ssh.VQUIT: 28,
// Backspace
ssh.VERASE: 127,
// Ctrl+U
ssh.VKILL: 21,
// Ctrl+D
ssh.VEOF: 4,
ssh.VEOL: 0,
ssh.VEOL2: 0,
// Ctrl+Q
ssh.VSTART: 17,
// Ctrl+S
ssh.VSTOP: 19,
// Ctrl+Z
ssh.VSUSP: 26,
// Ctrl+O
ssh.VDISCARD: 15,
// Ctrl+R
ssh.VREPRINT: 18,
// Ctrl+W
ssh.VWERASE: 23,
// Ctrl+V
ssh.VLNEXT: 22,
}
terminal := os.Getenv("TERM")
if terminal == "" {
terminal = "xterm-256color"
}
if err := session.RequestPty(terminal, h, w, modes); err != nil {
return fmt.Errorf("request pty: %w", err)
}
return nil
}

View File

@@ -0,0 +1,261 @@
//go:build windows
package client
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// ConsoleUnavailableError indicates that Windows console handles are not available
// (e.g., in CI environments where stdout/stdin are redirected)
type ConsoleUnavailableError struct {
Operation string
Err error
}
func (e *ConsoleUnavailableError) Error() string {
return fmt.Sprintf("console unavailable for %s: %v", e.Operation, e.Err)
}
func (e *ConsoleUnavailableError) Unwrap() error {
return e.Err
}
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
const (
enableProcessedInput = 0x0001
enableLineInput = 0x0002
enableEchoInput = 0x0004
enableVirtualTerminalProcessing = 0x0004
enableVirtualTerminalInput = 0x0200
)
type coord struct {
x, y int16
}
type smallRect struct {
left, top, right, bottom int16
}
type consoleScreenBufferInfo struct {
size coord
cursorPosition coord
attributes uint16
window smallRect
maximumWindowSize coord
}
func (c *Client) setupTerminalMode(_ context.Context, session *ssh.Session) error {
if err := c.saveWindowsConsoleState(); err != nil {
var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) {
log.Debugf("console unavailable, continuing with defaults: %v", err)
c.terminalFd = 0
} else {
return fmt.Errorf("save console state: %w", err)
}
}
if err := c.enableWindowsVirtualTerminal(); err != nil {
var consoleErr *ConsoleUnavailableError
if errors.As(err, &consoleErr) {
log.Debugf("virtual terminal unavailable: %v", err)
} else {
return fmt.Errorf("failed to enable virtual terminal: %w", err)
}
}
w, h := c.getWindowsConsoleSize()
modes := ssh.TerminalModes{
ssh.ECHO: 1,
ssh.TTY_OP_ISPEED: 14400,
ssh.TTY_OP_OSPEED: 14400,
ssh.ICRNL: 1,
ssh.OPOST: 1,
ssh.ONLCR: 1,
ssh.ISIG: 1,
ssh.ICANON: 1,
ssh.VINTR: 3, // Ctrl+C
ssh.VQUIT: 28, // Ctrl+\
ssh.VERASE: 127, // Backspace
ssh.VKILL: 21, // Ctrl+U
ssh.VEOF: 4, // Ctrl+D
ssh.VEOL: 0,
ssh.VEOL2: 0,
ssh.VSTART: 17, // Ctrl+Q
ssh.VSTOP: 19, // Ctrl+S
ssh.VSUSP: 26, // Ctrl+Z
ssh.VDISCARD: 15, // Ctrl+O
ssh.VWERASE: 23, // Ctrl+W
ssh.VLNEXT: 22, // Ctrl+V
ssh.VREPRINT: 18, // Ctrl+R
}
return session.RequestPty("xterm-256color", h, w, modes)
}
func (c *Client) saveWindowsConsoleState() error {
defer func() {
if r := recover(); r != nil {
log.Debugf("panic in saveWindowsConsoleState: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
var stdoutMode, stdinMode uint32
ret, _, err := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&stdoutMode)))
if ret == 0 {
log.Debugf("failed to get stdout console mode: %v", err)
return &ConsoleUnavailableError{
Operation: "get stdout console mode",
Err: err,
}
}
ret, _, err = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&stdinMode)))
if ret == 0 {
log.Debugf("failed to get stdin console mode: %v", err)
return &ConsoleUnavailableError{
Operation: "get stdin console mode",
Err: err,
}
}
c.terminalFd = 1
c.windowsStdoutMode = stdoutMode
c.windowsStdinMode = stdinMode
log.Debugf("saved Windows console state - stdout: 0x%04x, stdin: 0x%04x", stdoutMode, stdinMode)
return nil
}
func (c *Client) enableWindowsVirtualTerminal() (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in enableWindowsVirtualTerminal: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
var mode uint32
ret, _, winErr := procGetConsoleMode.Call(uintptr(stdout), uintptr(unsafe.Pointer(&mode)))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "get stdout console mode for VT",
Err: winErr,
}
}
mode |= enableVirtualTerminalProcessing
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdout), uintptr(mode))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "enable virtual terminal processing",
Err: winErr,
}
}
ret, _, winErr = procGetConsoleMode.Call(uintptr(stdin), uintptr(unsafe.Pointer(&mode)))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "get stdin console mode for VT",
Err: winErr,
}
}
mode &= ^uint32(enableLineInput | enableEchoInput | enableProcessedInput)
mode |= enableVirtualTerminalInput
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(mode))
if ret == 0 {
return &ConsoleUnavailableError{
Operation: "set stdin raw mode",
Err: winErr,
}
}
log.Debugf("enabled Windows virtual terminal processing")
return nil
}
func (c *Client) getWindowsConsoleSize() (int, int) {
defer func() {
if r := recover(); r != nil {
log.Debugf("panic in getWindowsConsoleSize: %v", r)
}
}()
stdout := syscall.Handle(os.Stdout.Fd())
var csbi consoleScreenBufferInfo
ret, _, err := procGetConsoleScreenBufferInfo.Call(uintptr(stdout), uintptr(unsafe.Pointer(&csbi)))
if ret == 0 {
log.Debugf("failed to get console buffer info, using defaults: %v", err)
return 80, 24
}
width := int(csbi.window.right - csbi.window.left + 1)
height := int(csbi.window.bottom - csbi.window.top + 1)
log.Debugf("Windows console size: %dx%d", width, height)
return width, height
}
func (c *Client) restoreWindowsConsoleState() error {
var err error
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic in restoreWindowsConsoleState: %v", r)
}
}()
if c.terminalFd != 1 {
return nil
}
stdout := syscall.Handle(os.Stdout.Fd())
stdin := syscall.Handle(os.Stdin.Fd())
ret, _, winErr := procSetConsoleMode.Call(uintptr(stdout), uintptr(c.windowsStdoutMode))
if ret == 0 {
log.Debugf("failed to restore stdout console mode: %v", winErr)
if err == nil {
err = fmt.Errorf("restore stdout console mode: %w", winErr)
}
}
ret, _, winErr = procSetConsoleMode.Call(uintptr(stdin), uintptr(c.windowsStdinMode))
if ret == 0 {
log.Debugf("failed to restore stdin console mode: %v", winErr)
if err == nil {
err = fmt.Errorf("restore stdin console mode: %w", winErr)
}
}
c.terminalFd = 0
c.windowsStdoutMode = 0
c.windowsStdinMode = 0
log.Debugf("restored Windows console state")
return err
}