Files
netbird/client/internal/profilemanager/config.go
Zoltan Papp 0efef671d7 [client] Unexport GetServerPublicKey, add HealthCheck method (#5735)
* Unexport GetServerPublicKey, add HealthCheck method

Internalize server key fetching into Login, Register,
GetDeviceAuthorizationFlow, and GetPKCEAuthorizationFlow methods,
removing the need for callers to fetch and pass the key separately.

Replace the exported GetServerPublicKey with a HealthCheck() error
method for connection validation, keeping IsHealthy() bool for
non-blocking background monitoring.

Fix test encryption to use correct key pairs (client public key as
remotePubKey instead of server private key).

* Refactor `doMgmLogin` to return only error, removing unused response
2026-04-07 12:18:21 +02:00

939 lines
29 KiB
Go

package profilemanager
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"net/url"
"os"
"os/user"
"path/filepath"
"reflect"
"runtime"
"slices"
"strings"
"time"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/util"
)
const (
// managementLegacyPortString is the port that was previously used by the Management gRPC server.
// It is kept for backward compatibility: UpdateOldManagementURL uses it to recognise
// configurations that still point at the legacy port.
managementLegacyPortString = "33073"
// DefaultManagementURL points to NetBird's cloud management endpoint.
DefaultManagementURL = "https://api.netbird.io:443"
// oldDefaultManagementURL points to NetBird's old cloud management endpoint.
oldDefaultManagementURL = "https://api.wiretrustee.com:443"
// DefaultAdminURL points to NetBird's cloud management console.
DefaultAdminURL = "https://app.netbird.io:443"
)
// mgmProber is the subset of the management client needed for URL migration probes.
// It is declared here, at the consumer, so tests can substitute a fake client.
type mgmProber interface {
	// HealthCheck probes the management service; a nil error means the probe succeeded.
	HealthCheck() error
	// Close shuts the probing client down once the probe is finished.
	Close() error
}
// newMgmProber creates a management client for probing URL reachability.
// It is a package-level variable so tests can override it and avoid real network calls.
// The default implementation delegates to mgm.NewClient with the same arguments.
var newMgmProber = func(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (mgmProber, error) {
return mgm.NewClient(ctx, addr, key, tlsEnabled)
}
// DefaultInterfaceBlacklist is the list of interface name entries that Config.apply copies
// into Config.IFaceBlackList whenever that list is empty.
// NOTE(review): entries such as "br-" look like name prefixes rather than full names —
// confirm against the code that consumes IFaceBlackList.
var DefaultInterfaceBlacklist = []string{
iface.WgInterfaceDefault, "wt", "utun", "tun0", "zt", "ZeroTier", "wg", "ts",
"Tailscale", "tailscale", "docker", "veth", "br-", "lo",
}
// ConfigInput carries configuration changes to the client.
//
// Config.apply treats a nil pointer, a nil slice, or an empty string field as
// "no change requested" and leaves the corresponding stored value untouched.
type ConfigInput struct {
// ManagementURL replaces the stored Management URL when non-empty.
ManagementURL string
// AdminURL replaces the stored Admin URL when non-empty.
AdminURL string
// ConfigPath is the config file that the update/create helpers read and write.
ConfigPath string
StateFilePath string
PreSharedKey *string
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
NATExternalIPs []string
// CustomDNSAddress is applied when non-nil; a non-nil empty value clears the stored address.
CustomDNSAddress []byte
RosenpassEnabled *bool
RosenpassPermissive *bool
InterfaceName *string
WireguardPort *int
NetworkMonitor *bool
DisableAutoConnect *bool
// ExtraIFaceBlackList entries that are not yet present are appended to Config.IFaceBlackList.
ExtraIFaceBlackList []string
DNSRouteInterval *time.Duration
// ClientCertPath and ClientCertKeyPath replace the stored mTLS paths when non-empty.
ClientCertPath string
ClientCertKeyPath string
DisableClientRoutes *bool
DisableServerRoutes *bool
DisableDNS *bool
DisableFirewall *bool
BlockLANAccess *bool
BlockInbound *bool
DisableNotifications *bool
DNSLabels domain.List
LazyConnectionEnabled *bool
MTU *uint16
}
// Config is the client configuration. It is the structure that the helpers in this
// file read from and write to the config file as JSON, and that Config.apply fills
// with defaults for any unset field.
type Config struct {
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
ManagementURL *url.URL
AdminURL *url.URL
// WgIface is the Wireguard interface name; apply defaults it to iface.WgInterfaceDefault when empty.
WgIface string
WgPort int
NetworkMonitor *bool
// IFaceBlackList is seeded from DefaultInterfaceBlacklist by apply when empty.
IFaceBlackList []string
DisableIPv6Discovery bool
RosenpassEnabled bool
RosenpassPermissive bool
ServerSSHAllowed *bool
EnableSSHRoot *bool
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
SSHJWTCacheTTL *int
DisableClientRoutes bool
DisableServerRoutes bool
DisableDNS bool
DisableFirewall bool
BlockLANAccess bool
BlockInbound bool
DisableNotifications *bool
DNSLabels domain.List
// SSHKey is a private SSH key in a PEM format
SSHKey string
// ExternalIP mappings, if different from the host interface IP
//
// External IP must not be behind a CGNAT and port-forwarding for incoming UDP packets from WgPort on ExternalIP
// to WgPort on host interface IP must be present. This can take form of single port-forwarding rule, 1:1 DNAT
// mapping ExternalIP to host interface IP, or a NAT DMZ to host interface IP.
//
// A single mapping will take the form of: external[/internal]
// external (required): either the external IP address or "stun" to use STUN to determine the external IP address
// internal (optional): either the internal/interface IP address or an interface name
//
// examples:
// "12.34.56.78" => all interfaces IPs will be mapped to external IP of 12.34.56.78
// "12.34.56.78/eth0" => IPv4 assigned to interface eth0 will be mapped to external IP of 12.34.56.78
// "12.34.56.78/10.1.2.3" => interface IP 10.1.2.3 will be mapped to external IP of 12.34.56.78
NATExternalIPs []string
// CustomDNSAddress sets the DNS resolver listening address in format ip:port
CustomDNSAddress string
// DisableAutoConnect determines whether the client should not start with the service
// it's set to false by default due to backwards compatibility
DisableAutoConnect bool
// DNSRouteInterval is the interval in which the DNS routes are updated
DNSRouteInterval time.Duration
// Path to a certificate used for mTLS authentication
ClientCertPath string
// Path to corresponding private key of ClientCertPath
ClientCertKeyPath string
// ClientCertKeyPair is the key pair loaded by apply from ClientCertPath and ClientCertKeyPath.
// It is excluded from the JSON representation.
ClientCertKeyPair *tls.Certificate `json:"-"`
LazyConnectionEnabled bool
// MTU defaults to iface.DefaultMTU in apply when it is zero.
MTU uint16
}
// ConfigDirOverride, when non-empty, is returned verbatim by getConfigDir and
// getConfigDirForUser instead of the directory they would otherwise compute.
var ConfigDirOverride string
// getConfigDir returns the directory that holds the NetBird configuration.
// ConfigDirOverride wins when set; otherwise the "netbird" directory below
// baseConfigDir is used and created (mode 0755) if it does not exist yet.
func getConfigDir() (string, error) {
	if override := ConfigDirOverride; override != "" {
		return override, nil
	}

	baseDir, err := baseConfigDir()
	if err != nil {
		return "", err
	}

	netbirdDir := filepath.Join(baseDir, "netbird")
	if mkErr := os.MkdirAll(netbirdDir, 0o755); mkErr != nil {
		return "", mkErr
	}
	return netbirdDir, nil
}
// baseConfigDir returns the per-user base directory for configuration files.
// On darwin it prefers "<home>/Library/Application Support" built from the home
// directory reported by user.Current; on every other OS, or when that lookup
// fails or yields an empty home, it falls back to os.UserConfigDir.
// NOTE(review): presumably the darwin branch exists so the path does not depend
// on the HOME environment variable — confirm.
func baseConfigDir() (string, error) {
	if runtime.GOOS != "darwin" {
		return os.UserConfigDir()
	}

	current, lookupErr := user.Current()
	if lookupErr != nil || current.HomeDir == "" {
		return os.UserConfigDir()
	}
	return filepath.Join(current.HomeDir, "Library", "Application Support"), nil
}
// getConfigDirForUser returns the configuration directory for the given user.
// ConfigDirOverride wins when set; otherwise the directory is
// DefaultConfigPathDir joined with the sanitized user name, and it is created
// with mode 0700 when a stat reports that it does not exist.
func getConfigDirForUser(username string) (string, error) {
	if ConfigDirOverride != "" {
		return ConfigDirOverride, nil
	}

	userDir := filepath.Join(DefaultConfigPathDir, sanitizeProfileName(username))

	// Only a "does not exist" result triggers creation; any other stat outcome
	// (success or a different error) returns the path as is.
	_, statErr := os.Stat(userDir)
	if !os.IsNotExist(statErr) {
		return userDir, nil
	}

	if mkErr := os.MkdirAll(userDir, 0700); mkErr != nil {
		return "", mkErr
	}
	return userDir, nil
}
// fileExists reports whether path can be stat'ed.
// A "not exist" error is reported as (false, nil); any other stat error is
// returned to the caller together with false.
func fileExists(path string) (bool, error) {
	switch _, statErr := os.Stat(path); {
	case statErr == nil:
		return true, nil
	case os.IsNotExist(statErr):
		return false, nil
	default:
		return false, statErr
	}
}
// createNewConfig builds a fresh in-memory Config from input.
// It seeds the values that differ for brand-new configurations and then runs
// apply, which generates the keys and fills in the remaining defaults.
// Nothing is written to disk here; callers persist the result themselves.
func createNewConfig(input ConfigInput) (*Config, error) {
	cfg := new(Config)
	// defaults to false only for new (post 0.26) configurations
	cfg.ServerSSHAllowed = util.False()
	cfg.WgPort = iface.DefaultWgPort

	_, err := cfg.apply(input)
	if err != nil {
		return nil, err
	}
	return cfg, nil
}
// apply merges input into config and fills in defaults for fields that are
// still unset.
//
// For every ConfigInput field, a nil pointer / nil slice / empty string means
// "no change requested". A requested value is stored only when it differs from
// the current one. Optional values are compared by value, not by pointer
// identity, so re-applying the same settings does not report a change.
//
// It returns updated == true when a field was changed or defaulted, so callers
// know the config needs to be persisted again. Defaulting the Management or
// Admin URL is not reported as an update. On error the returned flag is always
// false and config may be partially modified.
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
	// boolChanged reports whether an optional bool from the input carries a
	// value that differs from the (possibly unset) stored one.
	boolChanged := func(requested, current *bool) bool {
		return requested != nil && (current == nil || *requested != *current)
	}

	if config.ManagementURL == nil {
		log.Infof("using default Management URL %s", DefaultManagementURL)
		config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)
		if err != nil {
			return false, err
		}
	}

	// config.ManagementURL is guaranteed non-nil from here on.
	if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() {
		log.Infof("new Management URL provided, updated to %#v (old value %#v)",
			input.ManagementURL, config.ManagementURL.String())
		URL, err := parseURL("Management URL", input.ManagementURL)
		if err != nil {
			return false, err
		}
		config.ManagementURL = URL
		updated = true
	}

	if config.AdminURL == nil {
		log.Infof("using default Admin URL %s", DefaultAdminURL)
		config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL)
		if err != nil {
			return false, err
		}
	}

	if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() {
		log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)",
			input.AdminURL, config.AdminURL.String())
		newURL, err := parseURL("Admin Panel URL", input.AdminURL)
		if err != nil {
			return false, err
		}
		config.AdminURL = newURL
		updated = true
	}

	if config.PrivateKey == "" {
		log.Infof("generated new Wireguard key")
		config.PrivateKey = generateKey()
		updated = true
	}

	if config.SSHKey == "" {
		log.Infof("generated new SSH key")
		pem, err := ssh.GeneratePrivateKey(ssh.ED25519)
		if err != nil {
			return false, err
		}
		config.SSHKey = string(pem)
		updated = true
	}

	if input.WireguardPort != nil && *input.WireguardPort != config.WgPort {
		log.Infof("updating Wireguard port %d (old value %d)",
			*input.WireguardPort, config.WgPort)
		config.WgPort = *input.WireguardPort
		updated = true
	}

	if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
		log.Infof("updating Wireguard interface %#v (old value %#v)",
			*input.InterfaceName, config.WgIface)
		config.WgIface = *input.InterfaceName
		updated = true
	} else if config.WgIface == "" {
		config.WgIface = iface.WgInterfaceDefault
		log.Infof("using default Wireguard interface %s", config.WgIface)
		updated = true
	}

	if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) {
		log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])",
			strings.Join(input.NATExternalIPs, " "),
			strings.Join(config.NATExternalIPs, " "))
		config.NATExternalIPs = input.NATExternalIPs
		updated = true
	}

	if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey {
		log.Infof("new pre-shared key provided, replacing old key")
		config.PreSharedKey = *input.PreSharedKey
		updated = true
	}

	if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled {
		log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled)
		config.RosenpassEnabled = *input.RosenpassEnabled
		updated = true
	}

	if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive {
		log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive)
		config.RosenpassPermissive = *input.RosenpassPermissive
		updated = true
	}

	if boolChanged(input.NetworkMonitor, config.NetworkMonitor) {
		log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
		config.NetworkMonitor = input.NetworkMonitor
		updated = true
	}

	if config.NetworkMonitor == nil {
		// enable network monitoring by default on windows and darwin clients
		if runtime.GOOS == "windows" || runtime.GOOS == "darwin" {
			enabled := true
			config.NetworkMonitor = &enabled
			updated = true
		}
	}

	if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress {
		log.Infof("updating custom DNS address %#v (old value %#v)",
			string(input.CustomDNSAddress), config.CustomDNSAddress)
		config.CustomDNSAddress = string(input.CustomDNSAddress)
		updated = true
	}

	if len(config.IFaceBlackList) == 0 {
		log.Infof("filling in interface blacklist with defaults: [ %s ]",
			strings.Join(DefaultInterfaceBlacklist, " "))
		config.IFaceBlackList = append(config.IFaceBlackList, DefaultInterfaceBlacklist...)
		updated = true
	}

	if len(input.ExtraIFaceBlackList) > 0 {
		for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) {
			log.Infof("adding new entry to interface blacklist: %s", iFace)
			config.IFaceBlackList = append(config.IFaceBlackList, iFace)
			updated = true
		}
	}

	if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect {
		if *input.DisableAutoConnect {
			log.Infof("turning off automatic connection on startup")
		} else {
			log.Infof("enabling automatic connection on startup")
		}
		config.DisableAutoConnect = *input.DisableAutoConnect
		updated = true
	}

	// config.ServerSSHAllowed is nil for configurations written by old versions,
	// so it must not be dereferenced before the nil check inside boolChanged.
	if boolChanged(input.ServerSSHAllowed, config.ServerSSHAllowed) {
		if *input.ServerSSHAllowed {
			log.Infof("enabling SSH server")
		} else {
			log.Infof("disabling SSH server")
		}
		config.ServerSSHAllowed = input.ServerSSHAllowed
		updated = true
	} else if config.ServerSSHAllowed == nil {
		if runtime.GOOS == "android" {
			// default to disabled SSH on Android for security
			log.Infof("setting SSH server to false by default on Android")
			config.ServerSSHAllowed = util.False()
		} else {
			// enables SSH for configs from old versions to preserve backwards compatibility
			log.Infof("falling back to enabled SSH server for pre-existing configuration")
			config.ServerSSHAllowed = util.True()
		}
		updated = true
	}

	if boolChanged(input.EnableSSHRoot, config.EnableSSHRoot) {
		if *input.EnableSSHRoot {
			log.Infof("enabling SSH root login")
		} else {
			log.Infof("disabling SSH root login")
		}
		config.EnableSSHRoot = input.EnableSSHRoot
		updated = true
	}

	if boolChanged(input.EnableSSHSFTP, config.EnableSSHSFTP) {
		if *input.EnableSSHSFTP {
			log.Infof("enabling SSH SFTP subsystem")
		} else {
			log.Infof("disabling SSH SFTP subsystem")
		}
		config.EnableSSHSFTP = input.EnableSSHSFTP
		updated = true
	}

	if boolChanged(input.EnableSSHLocalPortForwarding, config.EnableSSHLocalPortForwarding) {
		if *input.EnableSSHLocalPortForwarding {
			log.Infof("enabling SSH local port forwarding")
		} else {
			log.Infof("disabling SSH local port forwarding")
		}
		config.EnableSSHLocalPortForwarding = input.EnableSSHLocalPortForwarding
		updated = true
	}

	if boolChanged(input.EnableSSHRemotePortForwarding, config.EnableSSHRemotePortForwarding) {
		if *input.EnableSSHRemotePortForwarding {
			log.Infof("enabling SSH remote port forwarding")
		} else {
			log.Infof("disabling SSH remote port forwarding")
		}
		config.EnableSSHRemotePortForwarding = input.EnableSSHRemotePortForwarding
		updated = true
	}

	if boolChanged(input.DisableSSHAuth, config.DisableSSHAuth) {
		if *input.DisableSSHAuth {
			log.Infof("disabling SSH authentication")
		} else {
			log.Infof("enabling SSH authentication")
		}
		config.DisableSSHAuth = input.DisableSSHAuth
		updated = true
	}

	if input.SSHJWTCacheTTL != nil &&
		(config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) {
		log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
		config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
		updated = true
	}

	if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
		log.Infof("updating DNS route interval to %s (old value %s)",
			input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
		config.DNSRouteInterval = *input.DNSRouteInterval
		updated = true
	} else if config.DNSRouteInterval == 0 {
		config.DNSRouteInterval = dynamic.DefaultInterval
		log.Infof("using default DNS route interval %s", config.DNSRouteInterval)
		updated = true
	}

	if input.DisableClientRoutes != nil && *input.DisableClientRoutes != config.DisableClientRoutes {
		if *input.DisableClientRoutes {
			log.Infof("disabling client routes")
		} else {
			log.Infof("enabling client routes")
		}
		config.DisableClientRoutes = *input.DisableClientRoutes
		updated = true
	}

	if input.DisableServerRoutes != nil && *input.DisableServerRoutes != config.DisableServerRoutes {
		if *input.DisableServerRoutes {
			log.Infof("disabling server routes")
		} else {
			log.Infof("enabling server routes")
		}
		config.DisableServerRoutes = *input.DisableServerRoutes
		updated = true
	}

	if input.DisableDNS != nil && *input.DisableDNS != config.DisableDNS {
		if *input.DisableDNS {
			log.Infof("disabling DNS configuration")
		} else {
			log.Infof("enabling DNS configuration")
		}
		config.DisableDNS = *input.DisableDNS
		updated = true
	}

	if input.DisableFirewall != nil && *input.DisableFirewall != config.DisableFirewall {
		if *input.DisableFirewall {
			log.Infof("disabling firewall configuration")
		} else {
			log.Infof("enabling firewall configuration")
		}
		config.DisableFirewall = *input.DisableFirewall
		updated = true
	}

	if input.BlockLANAccess != nil && *input.BlockLANAccess != config.BlockLANAccess {
		if *input.BlockLANAccess {
			log.Infof("blocking LAN access")
		} else {
			log.Infof("allowing LAN access")
		}
		config.BlockLANAccess = *input.BlockLANAccess
		updated = true
	}

	if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
		if *input.BlockInbound {
			log.Infof("blocking inbound connections")
		} else {
			log.Infof("allowing inbound connections")
		}
		config.BlockInbound = *input.BlockInbound
		updated = true
	}

	if boolChanged(input.DisableNotifications, config.DisableNotifications) {
		if *input.DisableNotifications {
			log.Infof("disabling notifications")
		} else {
			log.Infof("enabling notifications")
		}
		config.DisableNotifications = input.DisableNotifications
		updated = true
	}

	if config.DisableNotifications == nil {
		disabled := true
		config.DisableNotifications = &disabled
		log.Infof("setting notifications to disabled by default")
		updated = true
	}

	if input.ClientCertKeyPath != "" && input.ClientCertKeyPath != config.ClientCertKeyPath {
		config.ClientCertKeyPath = input.ClientCertKeyPath
		updated = true
	}

	if input.ClientCertPath != "" && input.ClientCertPath != config.ClientCertPath {
		config.ClientCertPath = input.ClientCertPath
		updated = true
	}

	// The key pair is (re)loaded on every apply because it is never persisted.
	// A load failure is logged and otherwise ignored: the config stays usable
	// without mTLS.
	if config.ClientCertPath != "" && config.ClientCertKeyPath != "" {
		cert, err := tls.LoadX509KeyPair(config.ClientCertPath, config.ClientCertKeyPath)
		if err != nil {
			log.Error("Failed to load mTLS cert/key pair: ", err)
		} else {
			config.ClientCertKeyPair = &cert
			log.Info("Loaded client mTLS cert/key pair")
		}
	}

	if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
		log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
			input.DNSLabels.SafeString(),
			config.DNSLabels.SafeString())
		config.DNSLabels = input.DNSLabels
		updated = true
	}

	if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
		log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
		config.LazyConnectionEnabled = *input.LazyConnectionEnabled
		updated = true
	}

	if input.MTU != nil && *input.MTU != config.MTU {
		log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
		config.MTU = *input.MTU
		updated = true
	} else if config.MTU == 0 {
		config.MTU = iface.DefaultMTU
		log.Infof("using default MTU %d", config.MTU)
		updated = true
	}

	return updated, nil
}
// parseURL parses and validates a service URL.
//
// serviceName is only used in log and error messages. serviceURL must be an
// absolute http or https URL; any other scheme is rejected. When the URL
// carries no port, the scheme's default port (443 for https, 80 for http) is
// appended to the host so callers always get an explicit host:port.
func parseURL(serviceName, serviceURL string) (*url.URL, error) {
	parsed, err := url.ParseRequestURI(serviceURL)
	if err != nil {
		log.Errorf("failed parsing %s URL %s: [%s]", serviceName, serviceURL, err.Error())
		return nil, err
	}

	var defaultPort string
	switch parsed.Scheme {
	case "https":
		defaultPort = "443"
	case "http":
		defaultPort = "80"
	default:
		return nil, fmt.Errorf(
			"invalid %s URL provided %s. Supported format [http|https]://[host]:[port]",
			serviceName, serviceURL)
	}

	if parsed.Port() == "" {
		// "host:" parses successfully with an empty port but keeps the trailing
		// colon in Host; drop it so the result is "host:443", not "host::443".
		parsed.Host = strings.TrimSuffix(parsed.Host, ":") + ":" + defaultPort
	}

	return parsed, nil
}
// generateKey returns a freshly generated Wireguard private key in its string form.
// It panics when key generation fails, since the signature leaves no way to
// report the error.
func generateKey() string {
	privateKey, genErr := wgtypes.GeneratePrivateKey()
	if genErr != nil {
		panic(genErr)
	}
	return privateKey.String()
}
// isPreSharedKeyHidden reports whether preSharedKey is the ten-asterisk
// placeholder the UI sends in place of the real key. Callers use it to avoid
// overwriting the stored pre-shared key with that placeholder.
func isPreSharedKeyHidden(preSharedKey *string) bool {
	if preSharedKey == nil {
		return false
	}
	return *preSharedKey == "**********"
}
// UpdateConfig applies input to the existing configuration file at
// input.ConfigPath and returns the resulting configuration.
// It fails when the file does not exist or its existence cannot be determined.
func UpdateConfig(input ConfigInput) (*Config, error) {
	exists, err := fileExists(input.ConfigPath)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to check if config file exists: %w", err)
	case !exists:
		return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
	}
	return update(input)
}
// UpdateOrCreateConfig applies input to the configuration file at
// input.ConfigPath, generating and writing a new configuration when the file
// does not exist yet.
//
// For an existing file, a pre-shared key that is only the UI's hidden
// placeholder is ignored, and a failure to enforce file permissions is logged
// without aborting the update.
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
	exists, err := fileExists(input.ConfigPath)
	if err != nil {
		return nil, fmt.Errorf("failed to check if config file exists: %w", err)
	}

	if exists {
		if isPreSharedKeyHidden(input.PreSharedKey) {
			input.PreSharedKey = nil
		}

		if permErr := util.EnforcePermission(input.ConfigPath); permErr != nil {
			log.Errorf("failed to enforce permission on config dir: %v", permErr)
		}
		return update(input)
	}

	log.Infof("generating new config %s", input.ConfigPath)
	created, err := createNewConfig(input)
	if err != nil {
		return nil, err
	}

	// The freshly created config is returned even if writing it out fails.
	writeErr := util.WriteJsonWithRestrictedPermission(context.Background(), input.ConfigPath, created)
	return created, writeErr
}
// update loads the configuration stored at input.ConfigPath, applies input to
// it and writes the file back only when apply reports a change.
func update(input ConfigInput) (*Config, error) {
	cfg := new(Config)
	if _, readErr := util.ReadJson(input.ConfigPath, cfg); readErr != nil {
		return nil, readErr
	}

	changed, err := cfg.apply(input)
	if err != nil {
		return nil, err
	}
	if !changed {
		return cfg, nil
	}

	if writeErr := util.WriteJson(context.Background(), input.ConfigPath, cfg); writeErr != nil {
		return nil, writeErr
	}
	return cfg, nil
}
// GetConfig reads the config file at configPath and returns the resulting Config.
// Unlike ReadConfig, it never creates the file: it returns an error if the file does not exist.
func GetConfig(configPath string) (*Config, error) {
return readConfig(configPath, false)
}
// UpdateOldManagementURL checks whether the client can switch to the new
// Management URL (the current management domain on port 443).
//
// The check is performed only for NetBird's managed version, i.e. when the
// configured Management URL uses https and either points at the old management
// domain or at the current domain on the legacy port. In every other case the
// provided config is returned unchanged with a nil error.
//
// When the switch applies, the new endpoint is probed with a health check. On
// success the config file at configPath is updated and the new config is
// returned. If probing or updating fails, the provided config is returned
// together with the error, so callers can keep using the old settings.
func UpdateOldManagementURL(ctx context.Context, config *Config, configPath string) (*Config, error) {
	defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL)
	if err != nil {
		return nil, err
	}

	parsedOldDefaultManagementURL, err := parseURL("Management URL", oldDefaultManagementURL)
	if err != nil {
		return nil, err
	}

	if config.ManagementURL.Hostname() != defaultManagementURL.Hostname() &&
		config.ManagementURL.Hostname() != parsedOldDefaultManagementURL.Hostname() {
		// only do the check for the NetBird's managed version
		return config, nil
	}

	mgmTlsEnabled := config.ManagementURL.Scheme == "https"
	if !mgmTlsEnabled {
		// only do the check for HTTPs scheme (the hosted version of the Management service is always HTTPs)
		return config, nil
	}

	// Already on the current domain and not on the legacy port: nothing to migrate.
	if config.ManagementURL.Port() != managementLegacyPortString &&
		config.ManagementURL.Hostname() == defaultManagementURL.Hostname() {
		return config, nil
	}

	newURL, err := parseURL("Management URL", fmt.Sprintf("%s://%s:%d",
		config.ManagementURL.Scheme, defaultManagementURL.Hostname(), 443))
	if err != nil {
		return nil, err
	}

	// here we check whether we could switch from the legacy 33073 port to the new 443
	log.Infof("attempting to switch from the legacy Management URL %s to the new one %s",
		config.ManagementURL.String(), newURL.String())

	key, err := wgtypes.ParseKey(config.PrivateKey)
	if err != nil {
		log.Infof("couldn't switch to the new Management %s", newURL.String())
		return config, err
	}

	client, err := newMgmProber(ctx, newURL.Host, key, mgmTlsEnabled)
	if err != nil {
		log.Infof("couldn't switch to the new Management %s", newURL.String())
		return config, err
	}
	defer func() {
		if err := client.Close(); err != nil {
			log.Warnf("failed to close the Management service client %v", err)
		}
	}()

	// gRPC check; on failure keep the caller on the provided config, exactly
	// like the other probe failures above.
	if err = client.HealthCheck(); err != nil {
		log.Infof("couldn't switch to the new Management %s", newURL.String())
		return config, err
	}

	// everything is alright => update the config
	newConfig, err := UpdateConfig(ConfigInput{
		ManagementURL: newURL.String(),
		ConfigPath:    configPath,
	})
	if err != nil {
		log.Infof("couldn't switch to the new Management %s", newURL.String())
		return config, fmt.Errorf("failed updating config file: %w", err)
	}

	log.Infof("successfully switched to the new Management URL: %s", newURL.String())

	return newConfig, nil
}
// CreateInMemoryConfig generates a new config from input without writing it to disk.
func CreateInMemoryConfig(input ConfigInput) (*Config, error) {
return createNewConfig(input)
}
// ReadConfig reads the config file at configPath and returns the resulting Config.
// If the file does not exist, a new config with default values is created and written to configPath.
func ReadConfig(configPath string) (*Config, error) {
return readConfig(configPath, true)
}
// readConfig loads the configuration stored at configPath.
//
// An existing file is read, run through apply with an empty input so defaults
// get filled in, and written back when that changed anything. A failure to
// enforce file permissions is only logged. When the file is missing, a new
// configuration is created and written if createIfMissing is true; otherwise
// an error is returned.
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
	exists, err := fileExists(configPath)
	if err != nil {
		return nil, fmt.Errorf("failed to check if config file exists: %w", err)
	}

	if !exists {
		if !createIfMissing {
			return nil, fmt.Errorf("config file %s does not exist", configPath)
		}

		created, createErr := createNewConfig(ConfigInput{ConfigPath: configPath})
		if createErr != nil {
			return nil, createErr
		}
		// The new config is returned even if writing it out fails.
		writeErr := WriteOutConfig(configPath, created)
		return created, writeErr
	}

	if permErr := util.EnforcePermission(configPath); permErr != nil {
		log.Errorf("failed to enforce permission on config dir: %v", permErr)
	}

	loaded := &Config{}
	if _, readErr := util.ReadJson(configPath, loaded); readErr != nil {
		return nil, readErr
	}

	// initialize through apply() without changes
	changed, err := loaded.apply(ConfigInput{})
	if err != nil {
		return nil, err
	}
	if changed {
		if writeErr := WriteOutConfig(configPath, loaded); writeErr != nil {
			return nil, writeErr
		}
	}

	return loaded, nil
}
// WriteOutConfig writes the prepared config to the given path as JSON via util.WriteJson.
func WriteOutConfig(path string, config *Config) error {
return util.WriteJson(context.Background(), path, config)
}
// DirectWriteOutConfig writes config to path directly, without atomic temp file operations.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox);
// everywhere else prefer WriteOutConfig.
func DirectWriteOutConfig(path string, config *Config) error {
return util.DirectWriteJson(context.Background(), path, config)
}
// DirectUpdateOrCreateConfig behaves like UpdateOrCreateConfig but persists the
// configuration with direct (non-atomic) writes.
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
	exists, err := fileExists(input.ConfigPath)
	if err != nil {
		return nil, fmt.Errorf("failed to check if config file exists: %w", err)
	}

	if exists {
		if isPreSharedKeyHidden(input.PreSharedKey) {
			input.PreSharedKey = nil
		}

		// Enforce permissions on existing config files (same as UpdateOrCreateConfig)
		if permErr := util.EnforcePermission(input.ConfigPath); permErr != nil {
			log.Errorf("failed to enforce permission on config file: %v", permErr)
		}
		return directUpdate(input)
	}

	log.Infof("generating new config %s", input.ConfigPath)
	created, err := createNewConfig(input)
	if err != nil {
		return nil, err
	}

	// The freshly created config is returned even if writing it out fails.
	writeErr := util.DirectWriteJson(context.Background(), input.ConfigPath, created)
	return created, writeErr
}
// directUpdate is the direct-write counterpart of update: it loads the
// configuration at input.ConfigPath, applies input and, only when apply reports
// a change, writes the file back without atomic temp file operations.
func directUpdate(input ConfigInput) (*Config, error) {
	cfg := new(Config)
	if _, readErr := util.ReadJson(input.ConfigPath, cfg); readErr != nil {
		return nil, readErr
	}

	changed, err := cfg.apply(input)
	if err != nil {
		return nil, err
	}
	if !changed {
		return cfg, nil
	}

	if writeErr := util.DirectWriteJson(context.Background(), input.ConfigPath, cfg); writeErr != nil {
		return nil, writeErr
	}
	return cfg, nil
}
// ConfigToJSON renders config as an indented JSON string.
// It is meant for exporting the configuration to storage other than the config
// file (e.g., UserDefaults on tvOS where file writes are blocked).
func ConfigToJSON(config *Config) (string, error) {
	encoded, marshalErr := json.MarshalIndent(config, "", " ")
	if marshalErr != nil {
		return "", marshalErr
	}

	return string(encoded), nil
}
// ConfigFromJSON rebuilds a Config from its JSON string form, the inverse of
// ConfigToJSON for configurations kept in alternative storage.
// After decoding, apply is run with an empty input so that unset fields get
// their defaults, mirroring what readConfig does after loading from file.
// Nothing is written to disk.
func ConfigFromJSON(jsonStr string) (*Config, error) {
	var cfg Config
	if decodeErr := json.Unmarshal([]byte(jsonStr), &cfg); decodeErr != nil {
		return nil, decodeErr
	}

	if _, applyErr := cfg.apply(ConfigInput{}); applyErr != nil {
		return nil, fmt.Errorf("failed to apply defaults to config: %w", applyErr)
	}

	return &cfg, nil
}