mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[client] Persist service install parameters across reinstalls (#5732)
This commit is contained in:
@@ -41,7 +41,7 @@ func init() {
|
|||||||
defaultServiceName = "Netbird"
|
defaultServiceName = "Netbird"
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
|
|
||||||
|
|||||||
@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
svcConfig, err := createServiceConfigForInstall()
|
svcConfig, err := createServiceConfigForInstall()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("install service: %w", err)
|
return fmt.Errorf("install service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("NetBird service has been installed")
|
cmd.Println("NetBird service has been installed")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
wasRunning, err := isServiceRunning()
|
wasRunning, err := isServiceRunning()
|
||||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||||
return fmt.Errorf("check service status: %w", err)
|
return fmt.Errorf("check service status: %w", err)
|
||||||
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return fmt.Errorf("install service with new config: %w", err)
|
return fmt.Errorf("install service with new config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
if wasRunning {
|
if wasRunning {
|
||||||
cmd.Println("Starting NetBird service...")
|
cmd.Println("Starting NetBird service...")
|
||||||
if err := s.Start(); err != nil {
|
if err := s.Start(); err != nil {
|
||||||
|
|||||||
201
client/cmd/service_params.go
Normal file
201
client/cmd/service_params.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serviceParamsFile = "service.json"
|
||||||
|
|
||||||
|
// serviceParams holds install-time service parameters that persist across
|
||||||
|
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||||
|
type serviceParams struct {
|
||||||
|
LogLevel string `json:"log_level"`
|
||||||
|
DaemonAddr string `json:"daemon_addr"`
|
||||||
|
ManagementURL string `json:"management_url,omitempty"`
|
||||||
|
ConfigPath string `json:"config_path,omitempty"`
|
||||||
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// serviceParamsPath returns the path to the service params file.
|
||||||
|
func serviceParamsPath() string {
|
||||||
|
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadServiceParams reads saved service parameters from disk.
|
||||||
|
// Returns nil with no error if the file does not exist.
|
||||||
|
func loadServiceParams() (*serviceParams, error) {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var params serviceParams
|
||||||
|
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ¶ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveServiceParams writes current service parameters to disk atomically
|
||||||
|
// with restricted permissions.
|
||||||
|
func saveServiceParams(params *serviceParams) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||||
|
return fmt.Errorf("save service params: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentServiceParams captures the current state of all package-level
|
||||||
|
// variables into a serviceParams struct.
|
||||||
|
func currentServiceParams() *serviceParams {
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: logLevel,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
LogFiles: logFiles,
|
||||||
|
DisableProfiles: profilesDisabled,
|
||||||
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(serviceEnvVars) > 0 {
|
||||||
|
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err == nil && len(parsed) > 0 {
|
||||||
|
params.ServiceEnvVars = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||||
|
// to any flags that were not explicitly set.
|
||||||
|
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
applyServiceParams(cmd, params)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceParams merges saved parameters into package-level variables
|
||||||
|
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||||
|
// Flags that were Changed() are left untouched.
|
||||||
|
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if params == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||||
|
// != "" guard so that an older service.json missing the field doesn't
|
||||||
|
// clobber the default with an empty string.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||||
|
logLevel = params.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||||
|
daemonAddr = params.DaemonAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// For optional fields where empty means "use default", always apply so
|
||||||
|
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||||
|
managementURL = params.ManagementURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("config") {
|
||||||
|
configPath = params.ConfigPath
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||||
|
logFiles = params.LogFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||||
|
profilesDisabled = params.DisableProfiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||||
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceEnvParams merges saved service environment variables.
|
||||||
|
// If --service-env was explicitly set, explicit values win on key conflict
|
||||||
|
// but saved keys not in the explicit set are carried over.
|
||||||
|
// If --service-env was not set, saved env vars are used entirely.
|
||||||
|
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if len(params.ServiceEnvVars) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cmd.Flags().Changed("service-env") {
|
||||||
|
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||||
|
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit env vars were provided: merge saved values underneath.
|
||||||
|
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||||
|
maps.Copy(merged, params.ServiceEnvVars)
|
||||||
|
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||||
|
serviceEnvVars = envMapToSlice(merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resetParamsCmd = &cobra.Command{
|
||||||
|
Use: "reset-params",
|
||||||
|
Short: "Remove saved service install parameters",
|
||||||
|
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
cmd.Println("No saved service parameters found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("remove service params: %w", err)
|
||||||
|
}
|
||||||
|
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||||
|
func envMapToSlice(m map[string]string) []string {
|
||||||
|
s := make([]string, 0, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
s = append(s, k+"="+v)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
523
client/cmd/service_params_test.go
Normal file
523
client/cmd/service_params_test.go
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"go/ast"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceParamsPath(t *testing.T) {
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
|
||||||
|
configs.StateDir = "/var/lib/netbird"
|
||||||
|
assert.Equal(t, "/var/lib/netbird/service.json", serviceParamsPath())
|
||||||
|
|
||||||
|
configs.StateDir = "/custom/state"
|
||||||
|
assert.Equal(t, "/custom/state/service.json", serviceParamsPath())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
ManagementURL: "https://my.server.com",
|
||||||
|
ConfigPath: "/etc/netbird/config.json",
|
||||||
|
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := saveServiceParams(params)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the file exists and is valid JSON.
|
||||||
|
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, json.Valid(data))
|
||||||
|
|
||||||
|
loaded, err := loadServiceParams()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loaded)
|
||||||
|
|
||||||
|
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||||
|
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||||
|
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||||
|
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||||
|
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||||
|
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||||
|
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrentServiceParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
logLevel = "trace"
|
||||||
|
daemonAddr = "tcp://127.0.0.1:9999"
|
||||||
|
managementURL = "https://mgmt.example.com"
|
||||||
|
configPath = "/tmp/test-config.json"
|
||||||
|
logFiles = []string{"/tmp/test.log"}
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||||
|
|
||||||
|
params := currentServiceParams()
|
||||||
|
|
||||||
|
assert.Equal(t, "trace", params.LogLevel)
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||||
|
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||||
|
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||||
|
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||||
|
assert.True(t, params.DisableProfiles)
|
||||||
|
assert.True(t, params.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset all flags to defaults.
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = "unix:///var/run/netbird.sock"
|
||||||
|
managementURL = ""
|
||||||
|
configPath = "/etc/netbird/config.json"
|
||||||
|
logFiles = []string{"/var/log/netbird/client.log"}
|
||||||
|
profilesDisabled = false
|
||||||
|
updateSettingsDisabled = false
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
// Reset Changed state on all relevant flags.
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate user explicitly setting --log-level via CLI.
|
||||||
|
logLevel = "warn"
|
||||||
|
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||||
|
ManagementURL: "https://saved.example.com",
|
||||||
|
ConfigPath: "/saved/config.json",
|
||||||
|
LogFiles: []string{"/saved/client.log"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: true,
|
||||||
|
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||||
|
assert.Equal(t, "warn", logLevel)
|
||||||
|
|
||||||
|
// All other fields were not Changed, so they should use saved values.
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||||
|
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||||
|
assert.Equal(t, "/saved/config.json", configPath)
|
||||||
|
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||||
|
assert.True(t, profilesDisabled)
|
||||||
|
assert.True(t, updateSettingsDisabled)
|
||||||
|
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
t.Cleanup(func() {
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
|
||||||
|
// Reset Changed state so flags appear unset.
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saved params have both as false.
|
||||||
|
saved := &serviceParams{
|
||||||
|
DisableProfiles: false,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||||
|
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||||
|
origManagementURL := managementURL
|
||||||
|
t.Cleanup(func() { managementURL = origManagementURL })
|
||||||
|
|
||||||
|
managementURL = "https://leftover.example.com"
|
||||||
|
|
||||||
|
// Simulate saved params where management URL was explicitly cleared.
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "info",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
t.Cleanup(func() { logLevel = origLogLevel })
|
||||||
|
|
||||||
|
logLevel = "info"
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
// Should be a no-op.
|
||||||
|
applyServiceParams(cmd, nil)
|
||||||
|
assert.Equal(t, "info", logLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Set up a command with --service-env marked as Changed.
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||||
|
|
||||||
|
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{
|
||||||
|
"SAVED": "val",
|
||||||
|
"OVERLAP": "saved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
// Parse result for easier assertion.
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||||
|
assert.Equal(t, "val", result["SAVED"])
|
||||||
|
// Explicit wins on conflict.
|
||||||
|
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||||
|
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||||
|
// added to serviceParams but not wired into these functions, this test fails.
|
||||||
|
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Collect all JSON field names from the serviceParams struct.
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||||
|
|
||||||
|
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||||
|
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||||
|
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||||
|
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||||
|
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||||
|
for k, v := range applyEnvFields {
|
||||||
|
applyFields[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range structFields {
|
||||||
|
assert.Contains(t, currentFields, field,
|
||||||
|
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||||
|
assert.Contains(t, applyFields, field,
|
||||||
|
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||||
|
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||||
|
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||||
|
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields)
|
||||||
|
|
||||||
|
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||||
|
fieldsNotInArgs := map[string]bool{
|
||||||
|
"ServiceEnvVars": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||||
|
|
||||||
|
// Forward: every struct field must appear in buildServiceArguments.
|
||||||
|
for _, field := range structFields {
|
||||||
|
if fieldsNotInArgs[field] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
globalVar := fieldToGlobalVar(field)
|
||||||
|
assert.Contains(t, buildFields, globalVar,
|
||||||
|
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse: every service-related global used in buildServiceArguments must
|
||||||
|
// have a corresponding serviceParams field. This catches a developer adding
|
||||||
|
// a new flag to buildServiceArguments without adding it to the struct.
|
||||||
|
globalToField := make(map[string]string, len(structFields))
|
||||||
|
for _, field := range structFields {
|
||||||
|
globalToField[fieldToGlobalVar(field)] = field
|
||||||
|
}
|
||||||
|
// Identifiers in buildServiceArguments that are not service params
|
||||||
|
// (builtins, boilerplate, loop variables).
|
||||||
|
nonParamGlobals := map[string]bool{
|
||||||
|
"args": true, "append": true, "string": true, "_": true,
|
||||||
|
"logFile": true, // range variable over logFiles
|
||||||
|
}
|
||||||
|
for ref := range buildFields {
|
||||||
|
if nonParamGlobals[ref] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, inStruct := globalToField[ref]
|
||||||
|
assert.True(t, inStruct,
|
||||||
|
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStructJSONFields returns field names from a named struct type.
|
||||||
|
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||||
|
t.Helper()
|
||||||
|
var fields []string
|
||||||
|
ast.Inspect(file, func(n ast.Node) bool {
|
||||||
|
ts, ok := n.(*ast.TypeSpec)
|
||||||
|
if !ok || ts.Name.Name != structName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
st, ok := ts.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, f := range st.Fields.List {
|
||||||
|
if len(f.Names) > 0 {
|
||||||
|
fields = append(fields, f.Names[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||||
|
// named function, either as selector expressions (params.FieldName) or as
|
||||||
|
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||||
|
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fieldSet := make(map[string]bool, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
fieldSet[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
switch v := n.(type) {
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
if fieldSet[v.Sel.Name] {
|
||||||
|
found[v.Sel.Name] = true
|
||||||
|
}
|
||||||
|
case *ast.KeyValueExpr:
|
||||||
|
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||||
|
found[ident.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||||
|
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
refs := make(map[string]bool)
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
if ident, ok := n.(*ast.Ident); ok {
|
||||||
|
refs[ident.Name] = true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return refs
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||||
|
for _, decl := range file.Decls {
|
||||||
|
fn, ok := decl.(*ast.FuncDecl)
|
||||||
|
if ok && fn.Name.Name == name {
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||||
|
// names used in buildServiceArguments and applyServiceParams.
|
||||||
|
func fieldToGlobalVar(field string) string {
|
||||||
|
m := map[string]string{
|
||||||
|
"LogLevel": "logLevel",
|
||||||
|
"DaemonAddr": "daemonAddr",
|
||||||
|
"ManagementURL": "managementURL",
|
||||||
|
"ConfigPath": "configPath",
|
||||||
|
"LogFiles": "logFiles",
|
||||||
|
"DisableProfiles": "profilesDisabled",
|
||||||
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
|
}
|
||||||
|
if v, ok := m[field]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
// Default: lowercase first letter.
|
||||||
|
return strings.ToLower(field[:1]) + field[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice(t *testing.T) {
|
||||||
|
m := map[string]string{"A": "1", "B": "2"}
|
||||||
|
s := envMapToSlice(m)
|
||||||
|
assert.Len(t, s, 2)
|
||||||
|
assert.Contains(t, s, "A=1")
|
||||||
|
assert.Contains(t, s, "B=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||||
|
s := envMapToSlice(map[string]string{})
|
||||||
|
assert.Empty(t, s)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user