From 02200d790bb8a5466f3b824ea6c6c2e624240b4a Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Wed, 26 Nov 2025 16:06:47 +0100 Subject: [PATCH] [client] Open browser for ssh automatically (#4838) --- client/cmd/login.go | 12 +----------- client/cmd/ssh.go | 33 ++++++++++++++++++++++++++++++- client/ssh/client/client.go | 18 ++++++++++++----- client/ssh/common.go | 36 ++++++++++++++++++++++++++++------ client/ssh/proxy/proxy.go | 30 +++++++++++++++------------- client/ssh/proxy/proxy_test.go | 2 +- util/common.go | 15 +++++++++++++- 7 files changed, 107 insertions(+), 39 deletions(-) diff --git a/client/cmd/login.go b/client/cmd/login.go index 2ddcccc8a..a34bb7c70 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -4,14 +4,12 @@ import ( "context" "fmt" "os" - "os/exec" "os/user" "runtime" "strings" "time" log "github.com/sirupsen/logrus" - "github.com/skratchdot/open-golang/open" "github.com/spf13/cobra" "google.golang.org/grpc/codes" gstatus "google.golang.org/grpc/status" @@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro cmd.Println("") if !noBrowser { - if err := openBrowser(verificationURIComplete); err != nil { + if err := util.OpenBrowser(verificationURIComplete); err != nil { cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + "https://docs.netbird.io/how-to/register-machines-using-setup-keys") } } } -// openBrowser opens the URL in a browser, respecting the BROWSER environment variable. -func openBrowser(url string) error { - if browser := os.Getenv("BROWSER"); browser != "" { - return exec.Command(browser, url).Start() - } - return open.Run(url) -} - // isUnixRunningDesktop checks if a Linux OS is running desktop environment func isUnixRunningDesktop() bool { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { diff --git a/client/cmd/ssh.go b/client/cmd/ssh.go index 92857c637..525bcdef1 100644 --- a/client/cmd/ssh.go +++ b/client/cmd/ssh.go @@ -51,6 +51,7 @@ var ( identityFile string skipCachedToken bool requestPTY bool + sshNoBrowser bool ) var ( @@ -81,6 +82,7 @@ func init() { sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc) sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") @@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string { return defaultValue } +// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes +func getBoolEnvOrDefault(flagName string, defaultValue bool) bool { + if envValue := os.Getenv("WT_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + if envValue := os.Getenv("NB_" + flagName); envValue != "" { + if parsed, err := strconv.ParseBool(envValue); err == nil { + return parsed + } + } + return defaultValue +} + // resetSSHGlobals sets SSH globals to their default values func resetSSHGlobals() { port = sshserver.DefaultSSHPort @@ -196,6 +213,7 @@ func resetSSHGlobals() { strictHostKeyChecking = true knownHostsFile = "" identityFile = "" + sshNoBrowser = false } // parseCustomSSHFlags extracts -L, -R flags and returns filtered args @@ -370,6 +388,7 @@ type sshFlags struct { KnownHostsFile string IdentityFile string SkipCachedToken bool + NoBrowser bool ConfigPath string LogLevel string LocalForwards []string @@ -381,6 +400,7 @@ type sshFlags struct { func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) + defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs.SetOutput(nil) @@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") + fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc) fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") @@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { knownHostsFile = flags.KnownHostsFile identityFile = flags.IdentityFile skipCachedToken = flags.SkipCachedToken + sshNoBrowser = flags.NoBrowser if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { configPath = flags.ConfigPath @@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { DaemonAddr: daemonAddr, SkipCachedToken: skipCachedToken, InsecureSkipVerify: !strictHostKeyChecking, + NoBrowser: sshNoBrowser, }) if err != nil { @@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid port: %s", portStr) } - proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) + // Check env var for browser setting since this command is invoked via SSH ProxyCommand + // where command-line flags cannot be passed. Default is to open browser. + noBrowser := getBoolEnvOrDefault("NO_BROWSER", false) + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener) if err != nil { return fmt.Errorf("create SSH proxy: %w", err) } diff --git a/client/ssh/client/client.go b/client/ssh/client/client.go index 31b80317a..aab222093 100644 --- a/client/ssh/client/client.go +++ b/client/ssh/client/client.go @@ -24,6 +24,7 @@ import ( "github.com/netbirdio/netbird/client/proto" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh/detection" + "github.com/netbirdio/netbird/util" ) const ( @@ -278,6 +279,7 @@ type DialOptions struct { DaemonAddr string SkipCachedToken bool InsecureSkipVerify bool + NoBrowser bool } // Dial connects to the given ssh server with specified options @@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er config.Auth = append(config.Auth, authMethod) } - return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) + return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser) } // dialSSH establishes an SSH connection without JWT authentication @@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig } // dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection -func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { +func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) { host, portStr, err := net.SplitHostPort(addr) if err != nil { return nil, fmt.Errorf("parse address %s: %w", addr, err) @@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) defer cancel() - jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) + jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser) if err != nil { return nil, fmt.Errorf("request JWT token: %w", err) } @@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo } // requestJWTToken requests a JWT token from the NetBird daemon -func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { +func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) { hint := profilemanager.GetLoginHint() conn, err := connectToDaemon(daemonAddr) @@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st defer conn.Close() client := proto.NewDaemonServiceClient(conn) - return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint) + + var browserOpener func(string) error + if !noBrowser { + browserOpener = util.OpenBrowser + } + + return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener) } // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon diff --git a/client/ssh/common.go b/client/ssh/common.go index 3beb12806..6574437b5 100644 --- a/client/ssh/common.go +++ b/client/ssh/common.go @@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe return VerifyHostKey(storedKeyData, presentedKey, peerAddress) } +// printAuthInstructions prints authentication instructions to stderr +func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) { + _, _ = fmt.Fprintln(stderr, "SSH authentication required.") + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.") + _, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:") + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete) + + if authResponse.UserCode != "" { + _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + } + + if browserWillOpen { + _, _ = fmt.Fprintln(stderr) + } + + _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") +} + // RequestJWTToken requests or retrieves a JWT token for SSH authentication -func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { +func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) { req := &proto.RequestJWTAuthRequest{} if hint != "" { req.Hint = &hint @@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo } if stderr != nil { - _, _ = fmt.Fprintln(stderr, "SSH authentication required.") - _, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) - if authResponse.UserCode != "" { - _, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) + printAuthInstructions(stderr, authResponse, openBrowser != nil) + } + + if openBrowser != nil { + if err := openBrowser(authResponse.VerificationURIComplete); err != nil { + log.Debugf("open browser: %v", err) } - _, _ = fmt.Fprintln(stderr, "Waiting for authentication...") } tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ diff --git a/client/ssh/proxy/proxy.go b/client/ssh/proxy/proxy.go index bc8a84b89..4e807e33c 100644 --- a/client/ssh/proxy/proxy.go +++ b/client/ssh/proxy/proxy.go @@ -35,15 +35,16 @@ const ( ) type SSHProxy struct { - daemonAddr string - targetHost string - targetPort int - stderr io.Writer - conn *grpc.ClientConn - daemonClient proto.DaemonServiceClient + daemonAddr string + targetHost string + targetPort int + stderr io.Writer + conn *grpc.ClientConn + daemonClient proto.DaemonServiceClient + browserOpener func(string) error } -func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { +func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) { grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { @@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP } return &SSHProxy{ - daemonAddr: daemonAddr, - targetHost: targetHost, - targetPort: targetPort, - stderr: stderr, - conn: grpcConn, - daemonClient: proto.NewDaemonServiceClient(grpcConn), + daemonAddr: daemonAddr, + targetHost: targetHost, + targetPort: targetPort, + stderr: stderr, + conn: grpcConn, + daemonClient: proto.NewDaemonServiceClient(grpcConn), + browserOpener: browserOpener, }, nil } @@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error { func (p *SSHProxy) Connect(ctx context.Context) error { hint := profilemanager.GetLoginHint() - jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) + jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener) if err != nil { return fmt.Errorf(jwtAuthErrorMsg, err) } diff --git a/client/ssh/proxy/proxy_test.go b/client/ssh/proxy/proxy_test.go index c5036da37..582f9c07b 100644 --- a/client/ssh/proxy/proxy_test.go +++ b/client/ssh/proxy/proxy_test.go @@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) { validToken := generateValidJWT(t, privateKey, issuer, audience) mockDaemon.setJWTToken(validToken) - proxyInstance, err := New(mockDaemon.addr, host, port, nil) + proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil) require.NoError(t, err) clientConn, proxyConn := net.Pipe() diff --git a/util/common.go b/util/common.go index 27adb9d13..89903b609 100644 --- a/util/common.go +++ b/util/common.go @@ -1,6 +1,19 @@ package util -import "os" +import ( + "os" + "os/exec" + + "github.com/skratchdot/open-golang/open" +) + +// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable. +func OpenBrowser(url string) error { + if browser := os.Getenv("BROWSER"); browser != "" { + return exec.Command(browser, url).Start() + } + return open.Run(url) +} // SliceDiff returns the elements in slice `x` that are not in slice `y` func SliceDiff(x, y []string) []string {