mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-17 15:56:39 +00:00
Address review
This commit is contained in:
@@ -54,7 +54,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
c.setupSessionIO(ctx, session)
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
@@ -64,7 +64,7 @@ func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
}
|
||||
|
||||
// setupSessionIO connects session streams to local terminal
|
||||
func (c *Client) setupSessionIO(ctx context.Context, session *ssh.Session) {
|
||||
func (c *Client) setupSessionIO(session *ssh.Session) {
|
||||
session.Stdout = os.Stdout
|
||||
session.Stderr = os.Stderr
|
||||
session.Stdin = os.Stdin
|
||||
@@ -143,7 +143,7 @@ func (c *Client) ExecuteCommandWithIO(ctx context.Context, command string) error
|
||||
}
|
||||
defer cleanup()
|
||||
|
||||
c.setupSessionIO(ctx, session)
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
@@ -180,7 +180,7 @@ func (c *Client) ExecuteCommandWithPTY(ctx context.Context, command string) erro
|
||||
return fmt.Errorf("setup terminal mode: %w", err)
|
||||
}
|
||||
|
||||
c.setupSessionIO(ctx, session)
|
||||
c.setupSessionIO(session)
|
||||
|
||||
if err := session.Start(command); err != nil {
|
||||
return fmt.Errorf("start command: %w", err)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
//go:build windows
|
||||
|
||||
package client
|
||||
|
||||
import (
|
||||
@@ -14,6 +12,21 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004 // Input mode: ENABLE_ECHO_INPUT
|
||||
enableVirtualTerminalProcessing = 0x0004 // Output mode: ENABLE_VIRTUAL_TERMINAL_PROCESSING (same value, different mode)
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
// ConsoleUnavailableError indicates that Windows console handles are not available
|
||||
// (e.g., in CI environments where stdout/stdin are redirected)
|
||||
type ConsoleUnavailableError struct {
|
||||
@@ -29,21 +42,6 @@ func (e *ConsoleUnavailableError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procGetConsoleMode = kernel32.NewProc("GetConsoleMode")
|
||||
procSetConsoleMode = kernel32.NewProc("SetConsoleMode")
|
||||
procGetConsoleScreenBufferInfo = kernel32.NewProc("GetConsoleScreenBufferInfo")
|
||||
)
|
||||
|
||||
const (
|
||||
enableProcessedInput = 0x0001
|
||||
enableLineInput = 0x0002
|
||||
enableEchoInput = 0x0004
|
||||
enableVirtualTerminalProcessing = 0x0004
|
||||
enableVirtualTerminalInput = 0x0200
|
||||
)
|
||||
|
||||
type coord struct {
|
||||
x, y int16
|
||||
}
|
||||
|
||||
@@ -89,7 +89,6 @@ type sshConnectionState struct {
|
||||
|
||||
// Server is the SSH server implementation
|
||||
type Server struct {
|
||||
listener net.Listener
|
||||
sshServer *ssh.Server
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.RWMutex
|
||||
@@ -138,16 +137,20 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
sshServer, err := s.createSSHServer(ln)
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.cleanupOnError(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.initializeServerState(ln, sshServer)
|
||||
s.sshServer = sshServer
|
||||
log.Infof("SSH server started on %s", addrDesc)
|
||||
|
||||
go s.serve(ln, sshServer)
|
||||
go func() {
|
||||
if err := sshServer.Serve(ln); !isShutdownError(err) {
|
||||
log.Errorf("SSH server error: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -186,12 +189,6 @@ func (s *Server) cleanupOnError(ln net.Listener) {
|
||||
s.closeListener(ln)
|
||||
}
|
||||
|
||||
// initializeServerState sets up server state after successful setup
|
||||
func (s *Server) initializeServerState(ln net.Listener, sshServer *ssh.Server) {
|
||||
s.listener = ln
|
||||
s.sshServer = sshServer
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
@@ -206,7 +203,6 @@ func (s *Server) Stop() error {
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
s.listener = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -254,7 +250,6 @@ func (s *Server) SetSocketFilter(ifIdx int) {
|
||||
s.ifIdx = ifIdx
|
||||
}
|
||||
|
||||
|
||||
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -370,25 +365,6 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
return conn
|
||||
}
|
||||
|
||||
// serve runs the SSH server in a goroutine
|
||||
func (s *Server) serve(ln net.Listener, sshServer *ssh.Server) {
|
||||
if ln == nil {
|
||||
log.Debug("SSH server serve called with nil listener")
|
||||
return
|
||||
}
|
||||
|
||||
err := sshServer.Serve(ln)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isShutdownError(err) {
|
||||
return
|
||||
}
|
||||
|
||||
log.Errorf("SSH server error: %v", err)
|
||||
}
|
||||
|
||||
// isShutdownError checks if the error is expected during normal shutdown
|
||||
func isShutdownError(err error) bool {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
@@ -404,13 +380,13 @@ func isShutdownError(err error) bool {
|
||||
}
|
||||
|
||||
// createSSHServer creates and configures the SSH server
|
||||
func (s *Server) createSSHServer(listener net.Listener) (*ssh.Server, error) {
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: listener.Addr().String(),
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": s.sftpSubsystemHandler,
|
||||
|
||||
Reference in New Issue
Block a user