[client] Increase ssh detection timeout (#4827)

This commit is contained in:
Viktor Liu
2025-11-20 17:09:22 +01:00
committed by GitHub
parent 68f56b797d
commit 1311364397
6 changed files with 46 additions and 35 deletions

View File

@@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
logOutput = firstLogFile
}
if err := util.InitLog(logLevel, logOutput); err != nil {
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
return fmt.Errorf("init log: %w", err)
}
@@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
if err := util.InitLog(detectLogLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
@@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
port, err := strconv.Atoi(portStr)
if err != nil {
log.Debugf("invalid port %q: %v", portStr, err)
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("SSH server detection failed: %v", err)
cancel()
os.Exit(detection.ServerTypeRegular.ExitCode())
}
cancel()
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection failed: %w", err)
return nil, fmt.Errorf("SSH server detection: %w", err)
}
if !serverType.RequiresJWT() {

View File

@@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
if runtime.GOOS == "windows" {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
} else {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
}
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"

View File

@@ -3,6 +3,7 @@ package detection
import (
"bufio"
"context"
"fmt"
"net"
"strconv"
"strings"
@@ -19,8 +20,8 @@ const (
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
// DefaultTimeout is the default timeout for SSH server detection
DefaultTimeout = 5 * time.Second
)
type ServerType string
@@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
}
defer conn.Close()
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
if deadline, ok := ctx.Deadline(); ok {
if err := conn.SetReadDeadline(deadline); err != nil {
return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
}
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
}
serverBanner = strings.TrimSpace(serverBanner)

View File

@@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) {
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
@@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) {
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
@@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) {
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)

View File

@@ -19,9 +19,10 @@ import (
)
const (
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn"
clientStartTimeout = 30 * time.Second
clientStopTimeout = 10 * time.Second
defaultLogLevel = "warn"
defaultSSHDetectionTimeout = 20 * time.Second
)
func main() {
@@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
host := args[0].String()
port := args[1].Int()
timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds())
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
timeoutMs = args[2].Int()
if timeoutMs <= 0 {
return js.ValueOf("error: timeout must be positive")
}
}
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
defer cancel()
serverType, err := detectSSHServerType(ctx, client, host, port)
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
return
@@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
})
}
// detectSSHServerType detects SSH server type using NetBird network connection
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
return sshdetection.DetectSSHServerType(ctx, client, host, port)
}
// createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{})