mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 20:26:39 +00:00
Merge branch 'main' into ssh-rewrite
This commit is contained in:
@@ -106,7 +106,6 @@ Examples:
|
||||
}
|
||||
|
||||
func sshFn(cmd *cobra.Command, args []string) error {
|
||||
// Check if help was requested
|
||||
for _, arg := range args {
|
||||
if arg == "-h" || arg == "--help" {
|
||||
return cmd.Help()
|
||||
@@ -116,14 +115,11 @@ func sshFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
|
||||
// Global flags were already parsed by validateSSHArgsWithoutFlagParsing
|
||||
// No additional parsing needed here
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
logOutput := "console"
|
||||
if logFile != "" && logFile != "/var/log/netbird/client.log" {
|
||||
logOutput = logFile
|
||||
if firstLogFile := util.FindFirstLogPath(logFiles); firstLogFile != "" && firstLogFile != "/var/log/netbird/client.log" {
|
||||
logOutput = firstLogFile
|
||||
}
|
||||
if err := util.InitLog(logLevel, logOutput); err != nil {
|
||||
return fmt.Errorf("init log: %w", err)
|
||||
@@ -233,7 +229,6 @@ func findSSHCommandPosition(args []string) int {
|
||||
const (
|
||||
configFlag = "config"
|
||||
logLevelFlag = "log-level"
|
||||
logFileFlag = "log-file"
|
||||
)
|
||||
|
||||
// parseGlobalArgs processes the global arguments and sets the corresponding variables
|
||||
@@ -241,7 +236,6 @@ func parseGlobalArgs(globalArgs []string) {
|
||||
flagHandlers := map[string]func(string){
|
||||
configFlag: func(value string) { configPath = value },
|
||||
logLevelFlag: func(value string) { logLevel = value },
|
||||
logFileFlag: func(value string) { logFile = value },
|
||||
}
|
||||
|
||||
shortFlags := map[string]string{
|
||||
@@ -334,10 +328,9 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string, *string) {
|
||||
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
defaultLogFile := getEnvOrDefault("LOG_FILE", logFile)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
@@ -358,9 +351,8 @@ func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string,
|
||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||
logFileFlag := fs.String("log-file", defaultLogFile, "sets Netbird log path")
|
||||
|
||||
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag, logFileFlag
|
||||
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
@@ -370,14 +362,13 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
|
||||
resetSSHGlobals()
|
||||
|
||||
// Extract global flags that were passed before 'ssh' by checking original command line
|
||||
if len(os.Args) > 2 {
|
||||
extractGlobalFlags(os.Args[1:])
|
||||
}
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag, logFileFlag := createSSHFlagSet()
|
||||
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
return parseHostnameAndCommand(filteredArgs)
|
||||
@@ -388,7 +379,6 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
// Set parsed values
|
||||
port = *portFlag
|
||||
if *userFlag != "" {
|
||||
username = *userFlag
|
||||
@@ -400,17 +390,12 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
knownHostsFile = *knownHostsFlag
|
||||
identityFile = *identityFlag
|
||||
|
||||
// Global flags were already extracted in extractGlobalFlags()
|
||||
// Only override with SSH-specific flags if they were explicitly provided
|
||||
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = *configFlag
|
||||
}
|
||||
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = *logLevelFlag
|
||||
}
|
||||
if *logFileFlag != getEnvOrDefault("LOG_FILE", logFile) {
|
||||
logFile = *logFileFlag
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
remoteForwards = remoteForwardFlags
|
||||
@@ -423,14 +408,12 @@ func parseHostnameAndCommand(args []string) error {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
// Parse hostname (possibly with user@host format)
|
||||
arg := args[0]
|
||||
if strings.Contains(arg, "@") {
|
||||
parts := strings.SplitN(arg, "@", 2)
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.New("invalid user@host format")
|
||||
}
|
||||
// Only use username from host if not already set by flags
|
||||
if username == "" {
|
||||
username = parts[0]
|
||||
}
|
||||
@@ -439,7 +422,6 @@ func parseHostnameAndCommand(args []string) error {
|
||||
host = arg
|
||||
}
|
||||
|
||||
// Set default username if none provided
|
||||
if username == "" {
|
||||
if sudoUser := os.Getenv("SUDO_USER"); sudoUser != "" {
|
||||
username = sudoUser
|
||||
|
||||
Reference in New Issue
Block a user