mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Increase ssh detection timeout (#4827)
This commit is contained in:
@@ -343,10 +343,13 @@ func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientCo
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
detectionCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(detectionCtx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
return nil, fmt.Errorf("SSH server detection: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
|
||||
@@ -189,12 +189,7 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
|
||||
@@ -3,6 +3,7 @@ package detection
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -19,8 +20,8 @@ const (
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
// DefaultTimeout is the default timeout for SSH server detection
|
||||
DefaultTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
@@ -61,21 +62,20 @@ func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port i
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
return ServerTypeRegular, fmt.Errorf("connect to %s: %w", targetAddr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
if deadline, ok := ctx.Deadline(); ok {
|
||||
if err := conn.SetReadDeadline(deadline); err != nil {
|
||||
return ServerTypeRegular, fmt.Errorf("set read deadline: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
return ServerTypeRegular, fmt.Errorf("read SSH banner: %w", err)
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
|
||||
@@ -58,7 +58,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
@@ -93,7 +93,7 @@ func TestJWTEnforcement(t *testing.T) {
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
@@ -218,7 +218,7 @@ func TestJWTDetection(t *testing.T) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
|
||||
Reference in New Issue
Block a user