From ceee421a0542133019d90b3a09b6bb4cb805e044 Mon Sep 17 00:00:00 2001 From: Krzysztof Nazarewski Date: Wed, 8 May 2024 18:58:31 +0200 Subject: [PATCH] unify Config generation, loading and updating (#1586) * config.go: pull unified Config.apply() out of createNewConfig() and update() as a bonus it ensures returned Config object doesn't have any configuration values missing --- client/internal/config.go | 352 +++++++++++++++++++------------------- 1 file changed, 180 insertions(+), 172 deletions(-) diff --git a/client/internal/config.go b/client/internal/config.go index 2beb853c0..66721cd21 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -5,6 +5,8 @@ import ( "fmt" "net/url" "os" + "reflect" + "strings" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -102,6 +104,14 @@ func ReadConfig(configPath string) (*Config, error) { if _, err := util.ReadJson(configPath, config); err != nil { return nil, err } + // initialize through apply() without changes + if changed, err := config.apply(ConfigInput{}); err != nil { + return nil, err + } else if changed { + if err = WriteOutConfig(configPath, config); err != nil { + return nil, err + } + } return config, nil } @@ -154,83 +164,15 @@ func WriteOutConfig(path string, config *Config) error { // createNewConfig creates a new config generating a new Wireguard key and saving to file func createNewConfig(input ConfigInput) (*Config, error) { - wgKey := generateKey() - pem, err := ssh.GeneratePrivateKey(ssh.ED25519) - if err != nil { - return nil, err - } - config := &Config{ - SSHKey: string(pem), - PrivateKey: wgKey, - IFaceBlackList: []string{}, - DisableIPv6Discovery: false, - NATExternalIPs: input.NATExternalIPs, - CustomDNSAddress: string(input.CustomDNSAddress), - ServerSSHAllowed: util.False(), - DisableAutoConnect: false, + // defaults to false only for new (post 0.26) configurations + ServerSSHAllowed: util.False(), } - defaultManagementURL, err := parseURL("Management URL", DefaultManagementURL) - if err != nil { + if _, err := config.apply(input); err != nil { return nil, err } - config.ManagementURL = defaultManagementURL - if input.ManagementURL != "" { - URL, err := parseURL("Management URL", input.ManagementURL) - if err != nil { - return nil, err - } - config.ManagementURL = URL - } - - config.WgPort = iface.DefaultWgPort - if input.WireguardPort != nil { - config.WgPort = *input.WireguardPort - } - - if input.NetworkMonitor != nil { - config.NetworkMonitor = *input.NetworkMonitor - } - - config.WgIface = iface.WgInterfaceDefault - if input.InterfaceName != nil { - config.WgIface = *input.InterfaceName - } - - if input.PreSharedKey != nil { - config.PreSharedKey = *input.PreSharedKey - } - - if input.RosenpassEnabled != nil { - config.RosenpassEnabled = *input.RosenpassEnabled - } - - if input.RosenpassPermissive != nil { - config.RosenpassPermissive = *input.RosenpassPermissive - } - - if input.ServerSSHAllowed != nil { - config.ServerSSHAllowed = input.ServerSSHAllowed - } - - defaultAdminURL, err := parseURL("Admin URL", DefaultAdminURL) - if err != nil { - return nil, err - } - - config.AdminURL = defaultAdminURL - if input.AdminURL != "" { - newURL, err := parseURL("Admin Panel URL", input.AdminURL) - if err != nil { - return nil, err - } - config.AdminURL = newURL - } - - // nolint:gocritic - config.IFaceBlackList = append(defaultInterfaceBlacklist, input.ExtraIFaceBlackList...) return config, nil } @@ -241,109 +183,12 @@ func update(input ConfigInput) (*Config, error) { return nil, err } - refresh := false - - if input.ManagementURL != "" && config.ManagementURL.String() != input.ManagementURL { - log.Infof("new Management URL provided, updated to %s (old value %s)", - input.ManagementURL, config.ManagementURL) - newURL, err := parseURL("Management URL", input.ManagementURL) - if err != nil { - return nil, err - } - config.ManagementURL = newURL - refresh = true + updated, err := config.apply(input) + if err != nil { + return nil, err } - if input.AdminURL != "" && (config.AdminURL == nil || config.AdminURL.String() != input.AdminURL) { - log.Infof("new Admin Panel URL provided, updated to %s (old value %s)", - input.AdminURL, config.AdminURL) - newURL, err := parseURL("Admin Panel URL", input.AdminURL) - if err != nil { - return nil, err - } - config.AdminURL = newURL - refresh = true - } - - if input.PreSharedKey != nil && config.PreSharedKey != *input.PreSharedKey { - log.Infof("new pre-shared key provided, replacing old key") - config.PreSharedKey = *input.PreSharedKey - refresh = true - } - - if config.SSHKey == "" { - pem, err := ssh.GeneratePrivateKey(ssh.ED25519) - if err != nil { - return nil, err - } - config.SSHKey = string(pem) - refresh = true - } - - if config.WgPort == 0 { - config.WgPort = iface.DefaultWgPort - refresh = true - } - - if input.NetworkMonitor != nil { - config.NetworkMonitor = *input.NetworkMonitor - refresh = true - } - - if input.WireguardPort != nil { - config.WgPort = *input.WireguardPort - refresh = true - } - - if input.InterfaceName != nil { - config.WgIface = *input.InterfaceName - refresh = true - } - - if input.NATExternalIPs != nil && len(config.NATExternalIPs) != len(input.NATExternalIPs) { - config.NATExternalIPs = input.NATExternalIPs - refresh = true - } - - if input.CustomDNSAddress != nil { - config.CustomDNSAddress = string(input.CustomDNSAddress) - refresh = true - } - - if input.RosenpassEnabled != nil { - config.RosenpassEnabled = *input.RosenpassEnabled - refresh = true - } - - if input.RosenpassPermissive != nil { - config.RosenpassPermissive = *input.RosenpassPermissive - refresh = true - } - - if input.DisableAutoConnect != nil { - config.DisableAutoConnect = *input.DisableAutoConnect - refresh = true - } - - if input.ServerSSHAllowed != nil { - config.ServerSSHAllowed = input.ServerSSHAllowed - refresh = true - } - - if config.ServerSSHAllowed == nil { - config.ServerSSHAllowed = util.True() - refresh = true - } - - if len(input.ExtraIFaceBlackList) > 0 { - for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { - config.IFaceBlackList = append(config.IFaceBlackList, iFace) - refresh = true - } - } - - if refresh { - // since we have new management URL, we need to update config file + if updated { if err := util.WriteJson(input.ConfigPath, config); err != nil { return nil, err } @@ -352,6 +197,169 @@ func update(input ConfigInput) (*Config, error) { return config, nil } +func (config *Config) apply(input ConfigInput) (updated bool, err error) { + if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err + } + } + if input.ManagementURL != "" && input.ManagementURL != config.ManagementURL.String() { + log.Infof("new Management URL provided, updated to %#v (old value %#v)", + input.ManagementURL, config.ManagementURL.String()) + URL, err := parseURL("Management URL", input.ManagementURL) + if err != nil { + return false, err + } + config.ManagementURL = URL + updated = true + } else if config.ManagementURL == nil { + log.Infof("using default Management URL %s", DefaultManagementURL) + config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL) + if err != nil { + return false, err + } + } + + if config.AdminURL == nil { + log.Infof("using default Admin URL %s", DefaultManagementURL) + config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) + if err != nil { + return false, err + } + } + if input.AdminURL != "" && input.AdminURL != config.AdminURL.String() { + log.Infof("new Admin Panel URL provided, updated to %#v (old value %#v)", + input.AdminURL, config.AdminURL.String()) + newURL, err := parseURL("Admin Panel URL", input.AdminURL) + if err != nil { + return updated, err + } + config.AdminURL = newURL + updated = true + } + + if config.PrivateKey == "" { + log.Infof("generated new Wireguard key") + config.PrivateKey = generateKey() + updated = true + } + + if config.SSHKey == "" { + log.Infof("generated new SSH key") + pem, err := ssh.GeneratePrivateKey(ssh.ED25519) + if err != nil { + return false, err + } + config.SSHKey = string(pem) + updated = true + } + + if input.WireguardPort != nil && *input.WireguardPort != config.WgPort { + log.Infof("updating Wireguard port %d (old value %d)", + *input.WireguardPort, config.WgPort) + config.WgPort = *input.WireguardPort + updated = true + } else if config.WgPort == 0 { + config.WgPort = iface.DefaultWgPort + log.Infof("using default Wireguard port %d", config.WgPort) + updated = true + } + + if input.InterfaceName != nil && *input.InterfaceName != config.WgIface { + log.Infof("updating Wireguard interface %#v (old value %#v)", + *input.InterfaceName, config.WgIface) + config.WgIface = *input.InterfaceName + updated = true + } else if config.WgIface == "" { + config.WgIface = iface.WgInterfaceDefault + log.Infof("using default Wireguard interface %s", config.WgIface) + updated = true + } + + if input.NATExternalIPs != nil && !reflect.DeepEqual(config.NATExternalIPs, input.NATExternalIPs) { + log.Infof("updating NAT External IP [ %s ] (old value: [ %s ])", + strings.Join(input.NATExternalIPs, " "), + strings.Join(config.NATExternalIPs, " ")) + config.NATExternalIPs = input.NATExternalIPs + updated = true + } + + if input.PreSharedKey != nil && *input.PreSharedKey != config.PreSharedKey { + log.Infof("new pre-shared key provided, replacing old key") + config.PreSharedKey = *input.PreSharedKey + updated = true + } + + if input.RosenpassEnabled != nil && *input.RosenpassEnabled != config.RosenpassEnabled { + log.Infof("switching Rosenpass to %t", *input.RosenpassEnabled) + config.RosenpassEnabled = *input.RosenpassEnabled + updated = true + } + + if input.RosenpassPermissive != nil && *input.RosenpassPermissive != config.RosenpassPermissive { + log.Infof("switching Rosenpass permissive to %t", *input.RosenpassPermissive) + config.RosenpassPermissive = *input.RosenpassPermissive + updated = true + } + + if input.NetworkMonitor != nil && *input.NetworkMonitor != config.NetworkMonitor { + log.Infof("switching Network Monitor to %t", *input.NetworkMonitor) + config.NetworkMonitor = *input.NetworkMonitor + updated = true + } + + if input.CustomDNSAddress != nil && string(input.CustomDNSAddress) != config.CustomDNSAddress { + log.Infof("updating custom DNS address %#v (old value %#v)", + string(input.CustomDNSAddress), config.CustomDNSAddress) + config.CustomDNSAddress = string(input.CustomDNSAddress) + updated = true + } + + if len(config.IFaceBlackList) == 0 { + log.Infof("filling in interface blacklist with defaults: [ %s ]", + strings.Join(defaultInterfaceBlacklist, " ")) + config.IFaceBlackList = append(config.IFaceBlackList, defaultInterfaceBlacklist...) + updated = true + } + + if len(input.ExtraIFaceBlackList) > 0 { + for _, iFace := range util.SliceDiff(input.ExtraIFaceBlackList, config.IFaceBlackList) { + log.Infof("adding new entry to interface blacklist: %s", iFace) + config.IFaceBlackList = append(config.IFaceBlackList, iFace) + updated = true + } + } + + if input.DisableAutoConnect != nil && *input.DisableAutoConnect != config.DisableAutoConnect { + if *input.DisableAutoConnect { + log.Infof("turning off automatic connection on startup") + } else { + log.Infof("enabling automatic connection on startup") + } + config.DisableAutoConnect = *input.DisableAutoConnect + updated = true + } + + if input.ServerSSHAllowed != nil && *input.ServerSSHAllowed != *config.ServerSSHAllowed { + if *input.ServerSSHAllowed { + log.Infof("enabling SSH server") + } else { + log.Infof("disabling SSH server") + } + config.ServerSSHAllowed = input.ServerSSHAllowed + updated = true + } else if config.ServerSSHAllowed == nil { + // enables SSH for configs from old versions to preserve backwards compatibility + log.Infof("falling back to enabled SSH server for pre-existing configuration") + config.ServerSSHAllowed = util.True() + updated = true + } + + return updated, nil +} + // parseURL parses and validates a service URL func parseURL(serviceName, serviceURL string) (*url.URL, error) { parsedMgmtURL, err := url.ParseRequestURI(serviceURL)