diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index f97ce5f90..dd9407738 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -9,6 +9,7 @@ import ( "strings" "syscall" + log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/netbirdio/netbird/client/internal" @@ -72,8 +73,8 @@ var sshCmd = &cobra.Command{ go func() { // blocking if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil { + log.Debug(err) os.Exit(1) - // log.Print(err) } cancel() }() @@ -95,7 +96,7 @@ func runSSH(ctx context.Context, addr string, pemKey []byte, cmd *cobra.Command) cmd.Printf("Couldn't connect. Please check the connection status or if the ssh server is enabled on the other peer" + "You can verify the connection by running:\n\n" + " netbird status\n\n") - return nil + return err } go func() { <-ctx.Done() diff --git a/client/ssh/server.go b/client/ssh/server.go index f08c5a2f1..b9128845e 100644 --- a/client/ssh/server.go +++ b/client/ssh/server.go @@ -20,6 +20,9 @@ import ( // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server const DefaultSSHPort = 44338 +// TerminalTimeout is the timeout for terminal session to be ready +const TerminalTimeout = 10 * time.Second + // DefaultSSHServer is a function that creates DefaultServer func DefaultSSHServer(hostKeyPEM []byte, addr string) (Server, error) { return newDefaultServer(hostKeyPEM, addr) @@ -213,42 +216,24 @@ func (srv *DefaultServer) stdInOut(file *os.File, session ssh.Session) { io.Copy(file, session) }() - // For nodes on AWS the terminal takes a while to be ready so we need to wait - terminalIsReady := make(chan bool) - go func() { - for { - log.Debugf("Checking if terminal is ready") - if checkIfFileIsReady(file) { - terminalIsReady <- true - } - time.Sleep(100 * time.Millisecond) - } - }() - timer := time.NewTimer(30 * time.Second) + timer := time.NewTimer(TerminalTimeout) for { select { case <-timer.C: session.Write([]byte("Reached timeout while opening connection\n")) session.Exit(1) - case <-terminalIsReady: + return + default: // stdout - io.Copy(session, file) - session.Exit(0) + writtenBytes, err := io.Copy(session, file) + if err != nil && writtenBytes != 0 { + session.Exit(0) + return + } } } } -func checkIfFileIsReady(file *os.File) bool { - buffer := make([]byte, 0) - _, err := file.Read(buffer) - // _, err := file.Stat() - // log.Infof("file stat: %v", err) - if err == nil { - return true - } - return false -} - // Start starts SSH server. Blocking func (srv *DefaultServer) Start() error { log.Infof("starting SSH server on addr: %s", srv.listener.Addr().String())