mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Increase ssh detection timeout (#4827)
This commit is contained in:
@@ -749,7 +749,9 @@ func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != defaultLogFile {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
|
||||
proxyLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(proxyLogLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
}
|
||||
|
||||
@@ -788,7 +790,8 @@ var sshDetectCmd = &cobra.Command{
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
detectLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
if err := util.InitLog(detectLogLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
@@ -797,15 +800,21 @@ func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
log.Debugf("invalid port %q: %v", portStr, err)
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
ctx, cancel := context.WithTimeout(cmd.Context(), detection.DefaultTimeout)
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("SSH server detection failed: %v", err)
|
||||
cancel()
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
cancel()
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
return nil, fmt.Errorf("SSH server detection: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
|
||||
@@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
|
||||
@@ -3,6 +3,7 @@ package detection
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,8 +20,8 @@ const (
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
// DefaultTimeout is the default timeout for SSH server detection
|
||||
DefaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
@@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
@@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
@@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
|
||||
@@ -19,9 +19,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
clientStartTimeout = 30 * time.Second
|
||||
clientStopTimeout = 10 * time.Second
|
||||
defaultLogLevel = "warn"
|
||||
defaultSSHDetectionTimeout = 20 * time.Second
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -207,11 +208,19 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
|
||||
timeoutMs := int(defaultSSHDetectionTimeout.Milliseconds())
|
||||
if len(args) >= 3 && !args[2].IsNull() && !args[2].IsUndefined() {
|
||||
timeoutMs = args[2].Int()
|
||||
if timeoutMs <= 0 {
|
||||
return js.ValueOf("error: timeout must be positive")
|
||||
}
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutMs)*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
serverType, err := detectSSHServerType(ctx, client, host, port)
|
||||
serverType, err := sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
@@ -222,11 +231,6 @@ func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
})
|
||||
}
|
||||
|
||||
// detectSSHServerType detects SSH server type using NetBird network connection
|
||||
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
|
||||
return sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||
}
|
||||
|
||||
// createClientObject wraps the NetBird client in a JavaScript object
|
||||
func createClientObject(client *netbird.Client) js.Value {
|
||||
obj := make(map[string]interface{})
|
||||
|
||||
Reference in New Issue
Block a user