mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ServerIdentifier is the base response for NetBird SSH servers
|
||||
ServerIdentifier = "NetBird-SSH-Server"
|
||||
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||
ServerTypeRegular ServerType = "regular"
|
||||
)
|
||||
|
||||
// Dialer provides network connection capabilities
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// RequiresJWT checks if the server type requires JWT authentication
|
||||
func (s ServerType) RequiresJWT() bool {
|
||||
return s == ServerTypeNetBirdJWT
|
||||
}
|
||||
|
||||
// ExitCode returns the exit code for the detect command
|
||||
func (s ServerType) ExitCode() int {
|
||||
switch s {
|
||||
case ServerTypeNetBirdJWT:
|
||||
return 0
|
||||
case ServerTypeNetBirdNoJWT:
|
||||
return 1
|
||||
case ServerTypeRegular:
|
||||
return 2
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
log.Debugf("SSH server banner: %s", serverBanner)
|
||||
|
||||
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||
log.Debugf("Invalid SSH banner")
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||
return ServerTypeNetBirdJWT, nil
|
||||
}
|
||||
|
||||
return ServerTypeNetBirdNoJWT, nil
|
||||
}
|
||||
Reference in New Issue
Block a user