Address review

This commit is contained in:
Viktor Liu
2025-08-26 21:00:33 +02:00
parent b1a9242c98
commit cdd5c6c005
3 changed files with 28 additions and 54 deletions

View File

@@ -54,7 +54,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error {
return err
}
c.setupSessionIO(ctx, session)
c.setupSessionIO(session)
if err := session.Shell(); err != nil {
return fmt.Errorf("start shell: %w", err)
@@ -64,7 +64,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error {
}
// setupSessionIO connects session streams to local terminal
func (c *Client) setupSessionIO(ctx context.Context, session *ssh.Session) {
func (c *Client) setupSessionIO(session *ssh.Session) {
session.Stdout = os.Stdout
session.Stderr = os.Stderr
session.Stdin = os.Stdin
@@ -143,7 +143,7 @@ func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error
}
defer cleanup()
c.setupSessionIO(ctx, session)
c.setupSessionIO(session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)
@@ -180,7 +180,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro
return fmt.Errorf("setup terminal mode: %w", err)
}
c.setupSessionIO(ctx, session)
c.setupSessionIO(session)
if err := session.Start(command); err != nil {
return fmt.Errorf("start command: %w", err)

View File

@@ -1,5 +1,3 @@
//go:build windows
package client
import (
@@ -14,6 +12,21 @@ import (
"golang.org/x/crypto/ssh"
)
const (
enableProcessedInput = 0x0001
enableLineInput = 0x0002
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
enableVirtualTerminalInput = 0x0200
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
// ConsoleUnavailableError indicates that Windows console handles are not available
// (e.g., in CI environments where stdout/stdin are redirected)
type ConsoleUnavailableError struct {
@@ -29,21 +42,6 @@ func (e *ConsoleUnavailableError) Unwrap() error {
return e.Err
}
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
)
const (
enableProcessedInput = 0x0001
enableLineInput = 0x0002
enableEchoInput = 0x0004
enableVirtualTerminalProcessing = 0x0004
enableVirtualTerminalInput = 0x0200
)
type coord struct {
x, y int16
}

View File

@@ -89,7 +89,6 @@ type sshConnectionState struct {
// Server is the SSH server implementation
type Server struct {
listener net.Listener
sshServer *ssh.Server
authorizedKeys map[string]ssh.PublicKey
mu sync.RWMutex
@@ -138,16 +137,20 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return fmt.Errorf("create listener: %w", err)
}
sshServer, err := s.createSSHServer(ln)
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.cleanupOnError(ln)
return fmt.Errorf("create SSH server: %w", err)
}
s.initializeServerState(ln, sshServer)
s.sshServer = sshServer
log.Infof("SSH server started on %s", addrDesc)
go s.serve(ln, sshServer)
go func() {
if err := sshServer.Serve(ln); !isShutdownError(err) {
log.Errorf("SSH server error: %v", err)
}
}()
return nil
}
@@ -186,12 +189,6 @@ func (s *Server) cleanupOnError(ln net.Listener) {
s.closeListener(ln)
}
// initializeServerState sets up server state after successful setup
func (s *Server) initializeServerState(ln net.Listener, sshServer *ssh.Server) {
s.listener = ln
s.sshServer = sshServer
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
@@ -206,7 +203,6 @@ func (s *Server) Stop() error {
}
s.sshServer = nil
s.listener = nil
return nil
}
@@ -254,7 +250,6 @@ func (s *Server) SetSocketFilter(ifIdx int) {
s.ifIdx = ifIdx
}
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -370,25 +365,6 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
return conn
}
// serve runs the SSH server in a goroutine
func (s *Server) serve(ln net.Listener, sshServer *ssh.Server) {
if ln == nil {
log.Debug("SSH server serve called with nil listener")
return
}
err := sshServer.Serve(ln)
if err == nil {
return
}
if isShutdownError(err) {
return
}
log.Errorf("SSH server error: %v", err)
}
// isShutdownError checks if the error is expected during normal shutdown
func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) {
@@ -404,13 +380,13 @@ func isShutdownError(err error) bool {
}
// createSSHServer creates and configures the SSH server
func (s *Server) createSSHServer(listener net.Listener) (*ssh.Server, error) {
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
server := &ssh.Server{
Addr: listener.Addr().String(),
Addr: addr.String(),
Handler: s.sessionHandler,
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": s.sftpSubsystemHandler,