package detection

import (
	"bufio"
	"context"
	"io"
	"net"
	"strconv"
	"strings"
	"time"

	log "github.com/sirupsen/logrus"
)

const (
	// ServerIdentifier is the base response for NetBird SSH servers.
	// DetectSSHServerType treats any SSH banner containing it as a NetBird server.
	ServerIdentifier = "NetBird-SSH-Server"

	// ProxyIdentifier is the base response for the NetBird SSH proxy.
	ProxyIdentifier = "NetBird-SSH-Proxy"

	// JWTRequiredMarker is appended to responses when JWT is required.
	// DetectSSHServerType looks for it as a substring of a NetBird banner.
	JWTRequiredMarker = "NetBird-JWT-Required"

	// Timeout bounds how long DetectSSHServerType waits for the server banner
	// after connecting. If ctx carries an earlier deadline, that one wins.
	Timeout = 5 * time.Second

	// maxBannerLength is the maximum SSH identification line length, CRLF
	// included, allowed by RFC 4253 section 4.2. It caps how much data an
	// untrusted peer can make us buffer while we look for the newline.
	maxBannerLength = 255
)

// ServerType classifies the SSH server found at a target address.
type ServerType string

const (
	// ServerTypeNetBirdJWT is a NetBird SSH server whose banner contains
	// JWTRequiredMarker.
	ServerTypeNetBirdJWT ServerType = "netbird-jwt"

	// ServerTypeNetBirdNoJWT is a NetBird SSH server whose banner does not
	// contain JWTRequiredMarker.
	ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"

	// ServerTypeRegular is any other server, including targets that could not
	// be reached or did not send a recognizable SSH banner.
	ServerTypeRegular ServerType = "regular"
)

// Dialer provides network connection capabilities.
type Dialer interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// RequiresJWT reports whether the server type requires JWT authentication.
func (s ServerType) RequiresJWT() bool {
	return s == ServerTypeNetBirdJWT
}

// ExitCode returns the exit code for the detect command: 0 for a NetBird
// server requiring JWT, 1 for a NetBird server without JWT, and 2 for a
// regular server or any unrecognized value.
func (s ServerType) ExitCode() int {
	switch s {
	case ServerTypeNetBirdJWT:
		return 0
	case ServerTypeNetBirdNoJWT:
		return 1
	default: // ServerTypeRegular and unknown values
		return 2
	}
}

// DetectSSHServerType connects to host:port through dialer, reads the first
// line the server sends (its SSH identification banner) and classifies it.
//
// Detection is best-effort: dial failures, read failures or timeouts,
// over-long banners, non-SSH banners and banners without ServerIdentifier all
// yield ServerTypeRegular. They are logged at debug level rather than
// returned. The error result is currently always nil.
//
// The banner read is bounded by Timeout or by ctx's deadline, whichever comes
// first.
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
	targetAddr := net.JoinHostPort(host, strconv.Itoa(port))

	conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
	if err != nil {
		log.Debugf("SSH connection failed for detection: %v", err)
		return ServerTypeRegular, nil
	}
	defer conn.Close()

	if err := conn.SetReadDeadline(bannerReadDeadline(ctx)); err != nil {
		log.Debugf("set read deadline: %v", err)
		return ServerTypeRegular, nil
	}

	// The banner comes from an untrusted peer: cap the read so a server that
	// never sends '\n' cannot make us buffer an unbounded amount of data.
	// Hitting the cap surfaces as io.EOF from ReadString.
	reader := bufio.NewReader(io.LimitReader(conn, maxBannerLength))
	serverBanner, err := reader.ReadString('\n')
	if err != nil {
		log.Debugf("read SSH banner: %v", err)
		return ServerTypeRegular, nil
	}
	serverBanner = strings.TrimSpace(serverBanner)
	// %q escapes any control characters a remote peer might embed.
	log.Debugf("SSH server banner: %q", serverBanner)

	if !strings.HasPrefix(serverBanner, "SSH-") {
		log.Debugf("Invalid SSH banner")
		return ServerTypeRegular, nil
	}

	if !strings.Contains(serverBanner, ServerIdentifier) {
		log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
		return ServerTypeRegular, nil
	}

	if strings.Contains(serverBanner, JWTRequiredMarker) {
		return ServerTypeNetBirdJWT, nil
	}

	return ServerTypeNetBirdNoJWT, nil
}

// bannerReadDeadline returns now+Timeout, or ctx's deadline if that is earlier,
// so detection never outlives the caller's own deadline.
func bannerReadDeadline(ctx context.Context) time.Time {
	deadline := time.Now().Add(Timeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		return ctxDeadline
	}
	return deadline
}