diff --git a/README.md b/README.md index 82ff42a..413d353 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ When Newt receives WireGuard control messages, it will use the information encod - `interface` (optional): Name of the WireGuard interface. Default: newt - `keep-interface` (optional): Keep the WireGuard interface. Default: false - `blueprint-file` (optional): Path to blueprint file to define Pangolin resources and configurations. +- `no-cloud` (optional): Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false ## Environment Variables @@ -86,6 +87,7 @@ All CLI arguments can be set using environment variables as an alternative to co - `KEEP_INTERFACE`: Keep the WireGuard interface after shutdown. Default: false (equivalent to `--keep-interface`) - `CONFIG_FILE`: Load the config json from this file instead of in the home folder. - `BLUEPRINT_FILE`: Path to blueprint file to define Pangolin resources and configurations. (equivalent to `--blueprint-file`) +- `NO_CLOUD`: Don't fail over to the cloud when using managed nodes in Pangolin Cloud. Default: false (equivalent to `--no-cloud`) ## Loading secrets from files diff --git a/main.go b/main.go index fb31cfe..b6ccc94 100644 --- a/main.go +++ b/main.go @@ -121,6 +121,7 @@ var ( healthMonitor *healthcheck.Monitor enforceHealthcheckCert bool blueprintFile string + noCloud bool // New mTLS configuration variables tlsClientCert string @@ -143,15 +144,13 @@ func main() { interfaceName = os.Getenv("INTERFACE") generateAndSaveKeyTo = os.Getenv("GENERATE_AND_SAVE_KEY_TO") keepInterfaceEnv := os.Getenv("KEEP_INTERFACE") - acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS") - useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE") - enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT") - keepInterface = keepInterfaceEnv == "true" + acceptClientsEnv := os.Getenv("ACCEPT_CLIENTS") acceptClients = acceptClientsEnv == "true" + useNativeInterfaceEnv := os.Getenv("USE_NATIVE_INTERFACE") useNativeInterface = useNativeInterfaceEnv == "true" + enforceHealthcheckCertEnv := os.Getenv("ENFORCE_HC_CERT") enforceHealthcheckCert = enforceHealthcheckCertEnv == "true" - dockerSocket = os.Getenv("DOCKER_SOCKET") pingIntervalStr := os.Getenv("PING_INTERVAL") pingTimeoutStr := os.Getenv("PING_TIMEOUT") @@ -179,6 +178,8 @@ func main() { tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") } blueprintFile = os.Getenv("BLUEPRINT_FILE") + noCloudEnv := os.Getenv("NO_CLOUD") + noCloud = noCloudEnv == "true" if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -281,6 +282,9 @@ func main() { if blueprintFile == "" { flag.StringVar(&blueprintFile, "blueprint-file", "", "Path to blueprint file (if unset, no blueprint will be applied)") } + if noCloudEnv == "" { + flag.BoolVar(&noCloud, "no-cloud", false, "Disable cloud failover") + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -635,7 +639,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } // Request exit nodes from the server - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + "noCloud": noCloud, + }, 3*time.Second) logger.Info("Tunnel destroyed, ready for reconnection") }) @@ -1237,8 +1243,10 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if stopFunc != nil { stopFunc() } - // request from the server the list of nodes to ping at newt/ping/request - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + // request from the server the list of nodes to ping + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{ + "noCloud": noCloud, + }, 3*time.Second) logger.Debug("Requesting exit nodes from server") clientsOnConnect() } diff --git a/websocket/client.go b/websocket/client.go index 0c0664a..c580f0e 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -37,6 +37,7 @@ type Client struct { writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") tlsConfig TLSConfig + configNeedsSave bool // Flag to track if config needs to be saved } type ClientOption func(*Client) diff --git a/websocket/config.go b/websocket/config.go index 1d821b6..72c9164 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -35,15 +35,25 @@ func getConfigPath(clientType string) string { } func (c *Client) loadConfig() error { + originalConfig := *c.config // Store original config to detect changes + configPath := getConfigPath(c.clientType) + if c.config.ID != "" && c.config.Secret != "" && c.config.Endpoint != "" { logger.Debug("Config already provided, skipping loading from file") + // Check if config file exists, if not, we should save it + if _, err := os.Stat(configPath); os.IsNotExist(err) { + logger.Info("Config file does not exist at %s, will create it", configPath) + c.configNeedsSave = true + } return nil } - configPath := getConfigPath(c.clientType) + logger.Info("Loading config from: %s", configPath) data, err := os.ReadFile(configPath) if err != nil { if os.IsNotExist(err) { + logger.Info("Config file does not exist at %s, will create it with provided values", configPath) + c.configNeedsSave = true return nil } return err @@ -54,6 +64,12 @@ func (c *Client) loadConfig() error { return err } + // Track what was loaded from file vs provided by CLI + fileHadID := c.config.ID == "" + fileHadSecret := c.config.Secret == "" + fileHadCert := c.config.TlsClientCert == "" + fileHadEndpoint := c.config.Endpoint == "" + if c.config.ID == "" { c.config.ID = config.ID } @@ -68,6 +84,15 @@ func (c *Client) loadConfig() error { c.baseURL = config.Endpoint } + // Check if CLI args provided values that override file values + if (!fileHadID && originalConfig.ID != "") || + (!fileHadSecret && originalConfig.Secret != "") || + (!fileHadCert && originalConfig.TlsClientCert != "") || + (!fileHadEndpoint && originalConfig.Endpoint != "") { + logger.Info("CLI arguments provided, config will be updated") + c.configNeedsSave = true + } + logger.Debug("Loaded config from %s", configPath) logger.Debug("Config: %+v", c.config) @@ -75,10 +100,21 @@ func (c *Client) loadConfig() error { } func (c *Client) saveConfig() error { + if !c.configNeedsSave { + logger.Debug("Config has not changed, skipping save") + return nil + } + configPath := getConfigPath(c.clientType) data, err := json.MarshalIndent(c.config, "", " ") if err != nil { return err } - return os.WriteFile(configPath, data, 0644) + + logger.Info("Saving config to: %s", configPath) + err = os.WriteFile(configPath, data, 0644) + if err == nil { + c.configNeedsSave = false // Reset flag after successful save + } + return err } diff --git a/wg/wg.go b/wg/wg.go index 5a512d6..a14e2c3 100644 --- a/wg/wg.go +++ b/wg/wg.go @@ -952,22 +952,30 @@ func (s *WireGuardService) encryptPayload(payload []byte) (interface{}, error) { } func (s *WireGuardService) keepSendingUDPHolePunch(host string) { + logger.Info("Starting UDP hole punch routine to %s:21820", host) + // send initial hole punch if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Error("Failed to send initial UDP hole punch: %v", err) + logger.Debug("Failed to send initial UDP hole punch: %v", err) } ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + for { select { case <-s.stopHolepunch: logger.Info("Stopping UDP holepunch") return + case <-timeout.C: + logger.Info("UDP holepunch routine timed out after 15 seconds") + return case <-ticker.C: if err := s.sendUDPHolePunch(host + ":21820"); err != nil { - logger.Error("Failed to send UDP hole punch: %v", err) + logger.Debug("Failed to send UDP hole punch: %v", err) } } } diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 08d740e..f6708e9 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -1076,11 +1076,17 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) { ticker := time.NewTicker(3 * time.Second) defer ticker.Stop() + timeout := time.NewTimer(15 * time.Second) + defer timeout.Stop() + for { select { case <-s.stopHolepunch: logger.Info("Stopping UDP holepunch") return + case <-timeout.C: + logger.Info("UDP holepunch routine timed out after 15 seconds") + return case <-ticker.C: if err := s.sendUDPHolePunch(host + ":21820"); err != nil { logger.Debug("Failed to send UDP hole punch: %v", err)