[client] Open browser for ssh automatically (#4838)

This commit is contained in:
Viktor Liu
2025-11-26 16:06:47 +01:00
committed by GitHub
parent f31bba87b4
commit 02200d790b
7 changed files with 107 additions and 39 deletions

View File

@@ -4,14 +4,12 @@ import (
"context" "context"
"fmt" "fmt"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status" gstatus "google.golang.org/grpc/status"
@@ -373,21 +371,13 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
cmd.Println("") cmd.Println("")
if !noBrowser { if !noBrowser {
if err := openBrowser(verificationURIComplete); err != nil { if err := util.OpenBrowser(verificationURIComplete); err != nil {
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" + cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
"https://docs.netbird.io/how-to/register-machines-using-setup-keys") "https://docs.netbird.io/how-to/register-machines-using-setup-keys")
} }
} }
} }
// openBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func openBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// isUnixRunningDesktop checks if a Linux OS is running desktop environment // isUnixRunningDesktop checks if a Linux OS is running desktop environment
func isUnixRunningDesktop() bool { func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" { if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {

View File

@@ -51,6 +51,7 @@ var (
identityFile string identityFile string
skipCachedToken bool skipCachedToken bool
requestPTY bool requestPTY bool
sshNoBrowser bool
) )
var ( var (
@@ -81,6 +82,7 @@ func init() {
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)") sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file (deprecated)")
_ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used") _ = sshCmd.PersistentFlags().MarkDeprecated("identity", "this flag is no longer used")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().BoolVar(&sshNoBrowser, noBrowserFlag, false, noBrowserDesc)
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
@@ -185,6 +187,21 @@ func getEnvOrDefault(flagName, defaultValue string) string {
return defaultValue return defaultValue
} }
// getBoolEnvOrDefault checks for boolean environment variables with WT_ and NB_ prefixes
func getBoolEnvOrDefault(flagName string, defaultValue bool) bool {
if envValue := os.Getenv("WT_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
if envValue := os.Getenv("NB_" + flagName); envValue != "" {
if parsed, err := strconv.ParseBool(envValue); err == nil {
return parsed
}
}
return defaultValue
}
// resetSSHGlobals sets SSH globals to their default values // resetSSHGlobals sets SSH globals to their default values
func resetSSHGlobals() { func resetSSHGlobals() {
port = sshserver.DefaultSSHPort port = sshserver.DefaultSSHPort
@@ -196,6 +213,7 @@ func resetSSHGlobals() {
strictHostKeyChecking = true strictHostKeyChecking = true
knownHostsFile = "" knownHostsFile = ""
identityFile = "" identityFile = ""
sshNoBrowser = false
} }
// parseCustomSSHFlags extracts -L, -R flags and returns filtered args // parseCustomSSHFlags extracts -L, -R flags and returns filtered args
@@ -370,6 +388,7 @@ type sshFlags struct {
KnownHostsFile string KnownHostsFile string
IdentityFile string IdentityFile string
SkipCachedToken bool SkipCachedToken bool
NoBrowser bool
ConfigPath string ConfigPath string
LogLevel string LogLevel string
LocalForwards []string LocalForwards []string
@@ -381,6 +400,7 @@ type sshFlags struct {
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) { func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
defaultNoBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil) fs.SetOutput(nil)
@@ -401,6 +421,7 @@ func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication") fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
fs.BoolVar(&flags.NoBrowser, "no-browser", defaultNoBrowser, noBrowserDesc)
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "config", defaultConfigPath, "Netbird config file location")
@@ -449,6 +470,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
knownHostsFile = flags.KnownHostsFile knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken skipCachedToken = flags.SkipCachedToken
sshNoBrowser = flags.NoBrowser
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) { if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath configPath = flags.ConfigPath
@@ -508,6 +530,7 @@ func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
DaemonAddr: daemonAddr, DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken, SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking, InsecureSkipVerify: !strictHostKeyChecking,
NoBrowser: sshNoBrowser,
}) })
if err != nil { if err != nil {
@@ -763,7 +786,15 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
return fmt.Errorf("invalid port: %s", portStr) return fmt.Errorf("invalid port: %s", portStr)
} }
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr()) // Check env var for browser setting since this command is invoked via SSH ProxyCommand
// where command-line flags cannot be passed. Default is to open browser.
noBrowser := getBoolEnvOrDefault("NO_BROWSER", false)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr(), browserOpener)
if err != nil { if err != nil {
return fmt.Errorf("create SSH proxy: %w", err) return fmt.Errorf("create SSH proxy: %w", err)
} }

View File

@@ -24,6 +24,7 @@ import (
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection" "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/util"
) )
const ( const (
@@ -278,6 +279,7 @@ type DialOptions struct {
DaemonAddr string DaemonAddr string
SkipCachedToken bool SkipCachedToken bool
InsecureSkipVerify bool InsecureSkipVerify bool
NoBrowser bool
} }
// Dial connects to the given ssh server with specified options // Dial connects to the given ssh server with specified options
@@ -307,7 +309,7 @@ func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, er
config.Auth = append(config.Auth, authMethod) config.Auth = append(config.Auth, authMethod)
} }
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken) return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken, opts.NoBrowser)
} }
// dialSSH establishes an SSH connection without JWT authentication // dialSSH establishes an SSH connection without JWT authentication
@@ -333,7 +335,7 @@ func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig
} }
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection // dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) { func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache, noBrowser bool) (*Client, error) {
host, portStr, err := net.SplitHostPort(addr) host, portStr, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse address %s: %w", addr, err) return nil, fmt.Errorf("parse address %s: %w", addr, err)
@@ -359,7 +361,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout) jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel() defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache) jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache, noBrowser)
if err != nil { if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err) return nil, fmt.Errorf("request JWT token: %w", err)
} }
@@ -369,7 +371,7 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
} }
// requestJWTToken requests a JWT token from the NetBird daemon // requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) { func requestJWTToken(ctx context.Context, daemonAddr string, skipCache, noBrowser bool) (string, error) {
hint := profilemanager.GetLoginHint() hint := profilemanager.GetLoginHint()
conn, err := connectToDaemon(daemonAddr) conn, err := connectToDaemon(daemonAddr)
@@ -379,7 +381,13 @@ func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (st
defer conn.Close() defer conn.Close()
client := proto.NewDaemonServiceClient(conn) client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint)
var browserOpener func(string) error
if !noBrowser {
browserOpener = util.OpenBrowser
}
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache, hint, browserOpener)
} }
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon

View File

@@ -67,8 +67,31 @@ func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKe
return VerifyHostKey(storedKeyData, presentedKey, peerAddress) return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
} }
// printAuthInstructions prints authentication instructions to stderr
func printAuthInstructions(stderr io.Writer, authResponse *proto.RequestJWTAuthResponse, browserWillOpen bool) {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
if browserWillOpen {
_, _ = fmt.Fprintln(stderr, "Please do the SSO login in your browser.")
_, _ = fmt.Fprintln(stderr, "If your browser didn't open automatically, use this URL to log in:")
_, _ = fmt.Fprintln(stderr)
}
_, _ = fmt.Fprintf(stderr, "%s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
if browserWillOpen {
_, _ = fmt.Fprintln(stderr)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication // RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) { func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string, openBrowser func(string) error) (string, error) {
req := &proto.RequestJWTAuthRequest{} req := &proto.RequestJWTAuthRequest{}
if hint != "" { if hint != "" {
req.Hint = &hint req.Hint = &hint
@@ -84,12 +107,13 @@ func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdo
} }
if stderr != nil { if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.") printAuthInstructions(stderr, authResponse, openBrowser != nil)
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete) }
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode) if openBrowser != nil {
if err := openBrowser(authResponse.VerificationURIComplete); err != nil {
log.Debugf("open browser: %v", err)
} }
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
} }
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{ tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{

View File

@@ -35,15 +35,16 @@ const (
) )
type SSHProxy struct { type SSHProxy struct {
daemonAddr string daemonAddr string
targetHost string targetHost string
targetPort int targetPort int
stderr io.Writer stderr io.Writer
conn *grpc.ClientConn conn *grpc.ClientConn
daemonClient proto.DaemonServiceClient daemonClient proto.DaemonServiceClient
browserOpener func(string) error
} }
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) { func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer, browserOpener func(string) error) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://") grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil { if err != nil {
@@ -51,12 +52,13 @@ func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHP
} }
return &SSHProxy{ return &SSHProxy{
daemonAddr: daemonAddr, daemonAddr: daemonAddr,
targetHost: targetHost, targetHost: targetHost,
targetPort: targetPort, targetPort: targetPort,
stderr: stderr, stderr: stderr,
conn: grpcConn, conn: grpcConn,
daemonClient: proto.NewDaemonServiceClient(grpcConn), daemonClient: proto.NewDaemonServiceClient(grpcConn),
browserOpener: browserOpener,
}, nil }, nil
} }
@@ -70,7 +72,7 @@ func (p *SSHProxy) Close() error {
func (p *SSHProxy) Connect(ctx context.Context) error { func (p *SSHProxy) Connect(ctx context.Context) error {
hint := profilemanager.GetLoginHint() hint := profilemanager.GetLoginHint()
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint) jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true, hint, p.browserOpener)
if err != nil { if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err) return fmt.Errorf(jwtAuthErrorMsg, err)
} }

View File

@@ -153,7 +153,7 @@ func TestSSHProxy_Connect(t *testing.T) {
validToken := generateValidJWT(t, privateKey, issuer, audience) validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken) mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil) proxyInstance, err := New(mockDaemon.addr, host, port, nil, nil)
require.NoError(t, err) require.NoError(t, err)
clientConn, proxyConn := net.Pipe() clientConn, proxyConn := net.Pipe()

View File

@@ -1,6 +1,19 @@
package util package util
import "os" import (
"os"
"os/exec"
"github.com/skratchdot/open-golang/open"
)
// OpenBrowser opens the URL in a browser, respecting the BROWSER environment variable.
func OpenBrowser(url string) error {
if browser := os.Getenv("BROWSER"); browser != "" {
return exec.Command(browser, url).Start()
}
return open.Run(url)
}
// SliceDiff returns the elements in slice `x` that are not in slice `y` // SliceDiff returns the elements in slice `x` that are not in slice `y`
func SliceDiff(x, y []string) []string { func SliceDiff(x, y []string) []string {