diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index defa16247..ea2fa409a 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -54,7 +54,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error { return err } - c.setupSessionIO(ctx, session) + c.setupSessionIO(session) if err := session.Shell(); err != nil { return fmt.Errorf("start shell: %w", err) @@ -64,7 +64,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error { } // setupSessionIO connects session streams to local terminal -func (c *Client) setupSessionIO(ctx context.Context, session *ssh.Session) { +func (c *Client) setupSessionIO(session *ssh.Session) { session.Stdout = os.Stdout session.Stderr = os.Stderr session.Stdin = os.Stdin @@ -143,7 +143,7 @@ func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error } defer cleanup() - c.setupSessionIO(ctx, session) + c.setupSessionIO(session) if err := session.Start(command); err != nil { return fmt.Errorf("start command: %w", err) @@ -180,7 +180,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro return fmt.Errorf("setup terminal mode: %w", err) } - c.setupSessionIO(ctx, session) + c.setupSessionIO(session) if err := session.Start(command); err != nil { return fmt.Errorf("start command: %w", err) diff --git a/client/ssh/client/terminal_windows.go b/client/ssh/client/terminal_windows.go index 84ac7ff56..438d538c4 100644 --- a/client/ssh/client/terminal_windows.go +++ b/client/ssh/client/terminal_windows.go @@ -1,5 +1,3 @@ -//go:build windows - package client import ( @@ -14,6 +12,21 @@ import ( "golang.org/x/crypto/ssh" ) +const ( + enableProcessedInput = 0x0001 + enableLineInput = 0x0002 + enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT + enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode) + enableVirtualTerminalInput = 0x0200 +) + +var ( + kernel32 = syscall.NewLazyDLL("kernel32.dll") + procGetConsoleMode = kernel32.NewProc("GetConsoleMode") + procSetConsoleMode = kernel32.NewProc("SetConsoleMode") + procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") +) + // ConsoleUnavailableError indicates that Windows console handles are not available // (e.g., in CI environments where stdout/stdin are redirected) type ConsoleUnavailableError struct { @@ -29,21 +42,6 @@ func (e *ConsoleUnavailableError) Unwrap() error { return e.Err } -var ( - kernel32 = syscall.NewLazyDLL("kernel32.dll") - procGetConsoleMode = kernel32.NewProc("GetConsoleMode") - procSetConsoleMode = kernel32.NewProc("SetConsoleMode") - procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo") -) - -const ( - enableProcessedInput = 0x0001 - enableLineInput = 0x0002 - enableEchoInput = 0x0004 - enableVirtualTerminalProcessing = 0x0004 - enableVirtualTerminalInput = 0x0200 -) - type coord struct { x, y int16 } diff --git a/client/ssh/server/server.go b/client/ssh/server/server.go index f8830f972..5f680cea5 100644 --- a/client/ssh/server/server.go +++ b/client/ssh/server/server.go @@ -89,7 +89,6 @@ type sshConnectionState struct { // Server is the SSH server implementation type Server struct { - listener net.Listener sshServer *ssh.Server authorizedKeys map[string]ssh.PublicKey mu sync.RWMutex @@ -138,16 +137,20 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { return fmt.Errorf("create listener: %w", err) } - sshServer, err := s.createSSHServer(ln) + sshServer, err := s.createSSHServer(ln.Addr()) if err != nil { s.cleanupOnError(ln) return fmt.Errorf("create SSH server: %w", err) } - s.initializeServerState(ln, sshServer) + s.sshServer = sshServer log.Infof("SSH server started on %s", addrDesc) - go s.serve(ln, sshServer) + go func() { + if err := sshServer.Serve(ln); !isShutdownError(err) { + log.Errorf("SSH server error: %v", err) + } + }() return nil } @@ -186,12 +189,6 @@ func (s *Server) cleanupOnError(ln net.Listener) { s.closeListener(ln) } -// initializeServerState sets up server state after successful setup -func (s *Server) initializeServerState(ln net.Listener, sshServer *ssh.Server) { - s.listener = ln - s.sshServer = sshServer -} - // Stop closes the SSH server func (s *Server) Stop() error { s.mu.Lock() @@ -206,7 +203,6 @@ func (s *Server) Stop() error { } s.sshServer = nil - s.listener = nil return nil } @@ -254,7 +250,6 @@ func (s *Server) SetSocketFilter(ifIdx int) { s.ifIdx = ifIdx } - func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { s.mu.RLock() defer s.mu.RUnlock() @@ -370,25 +365,6 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { return conn } -// serve runs the SSH server in a goroutine -func (s *Server) serve(ln net.Listener, sshServer *ssh.Server) { - if ln == nil { - log.Debug("SSH server serve called with nil listener") - return - } - - err := sshServer.Serve(ln) - if err == nil { - return - } - - if isShutdownError(err) { - return - } - - log.Errorf("SSH server error: %v", err) -} - // isShutdownError checks if the error is expected during normal shutdown func isShutdownError(err error) bool { if errors.Is(err, net.ErrClosed) { @@ -404,13 +380,13 @@ func isShutdownError(err error) bool { } // createSSHServer creates and configures the SSH server -func (s *Server) createSSHServer(listener net.Listener) (*ssh.Server, error) { +func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { if err := enableUserSwitching(); err != nil { log.Warnf("failed to enable user switching: %v", err) } server := &ssh.Server{ - Addr: listener.Addr().String(), + Addr: addr.String(), Handler: s.sessionHandler, SubsystemHandlers: map[string]ssh.SubsystemHandler{ "sftp": s.sftpSubsystemHandler,