diff --git a/config.go b/config.go index 1f7f0d4..4364a78 100644 --- a/config.go +++ b/config.go @@ -38,8 +38,9 @@ type OlmConfig struct { PingTimeout string `json:"pingTimeout"` // Advanced - Holepunch bool `json:"holepunch"` - TlsClientCert string `json:"tlsClientCert"` + Holepunch bool `json:"holepunch"` + TlsClientCert string `json:"tlsClientCert"` + DoNotCreateNewClient bool `json:"doNotCreateNewClient"` // Parsed values (not in JSON) PingIntervalDuration time.Duration `json:"-"` @@ -73,16 +74,17 @@ func DefaultConfig() *OlmConfig { } config := &OlmConfig{ - MTU: 1280, - DNS: "8.8.8.8", - LogLevel: "INFO", - InterfaceName: "olm", - EnableAPI: false, - SocketPath: socketPath, - PingInterval: "3s", - PingTimeout: "5s", - Holepunch: false, - sources: make(map[string]string), + MTU: 1280, + DNS: "8.8.8.8", + LogLevel: "INFO", + InterfaceName: "olm", + EnableAPI: false, + SocketPath: socketPath, + PingInterval: "3s", + PingTimeout: "5s", + Holepunch: false, + DoNotCreateNewClient: false, + sources: make(map[string]string), } // Track default sources @@ -96,6 +98,7 @@ func DefaultConfig() *OlmConfig { config.sources["pingInterval"] = string(SourceDefault) config.sources["pingTimeout"] = string(SourceDefault) config.sources["holepunch"] = string(SourceDefault) + config.sources["doNotCreateNewClient"] = string(SourceDefault) return config } @@ -242,6 +245,10 @@ func loadConfigFromEnv(config *OlmConfig) { config.Holepunch = true config.sources["holepunch"] = string(SourceEnv) } + if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" { + config.DoNotCreateNewClient = true + config.sources["doNotCreateNewClient"] = string(SourceEnv) + } } // loadConfigFromCLI loads configuration from command-line arguments @@ -250,21 +257,22 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { // Store original values to detect changes origValues := map[string]interface{}{ - "endpoint": config.Endpoint, - "id": config.ID, - "secret": config.Secret, - "org": config.OrgID, - "userToken": config.UserToken, - "mtu": config.MTU, - "dns": config.DNS, - "logLevel": config.LogLevel, - "interface": config.InterfaceName, - "httpAddr": config.HTTPAddr, - "socketPath": config.SocketPath, - "pingInterval": config.PingInterval, - "pingTimeout": config.PingTimeout, - "enableApi": config.EnableAPI, - "holepunch": config.Holepunch, + "endpoint": config.Endpoint, + "id": config.ID, + "secret": config.Secret, + "org": config.OrgID, + "userToken": config.UserToken, + "mtu": config.MTU, + "dns": config.DNS, + "logLevel": config.LogLevel, + "interface": config.InterfaceName, + "httpAddr": config.HTTPAddr, + "socketPath": config.SocketPath, + "pingInterval": config.PingInterval, + "pingTimeout": config.PingTimeout, + "enableApi": config.EnableAPI, + "holepunch": config.Holepunch, + "doNotCreateNewClient": config.DoNotCreateNewClient, } // Define flags @@ -283,6 +291,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.StringVar(&config.PingTimeout, "ping-timeout", config.PingTimeout, "Timeout for each ping") serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") + serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") version := serviceFlags.Bool("version", false, "Print the version") showConfig := serviceFlags.Bool("show-config", false, "Show configuration sources and exit") @@ -338,6 +347,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.Holepunch != origValues["holepunch"].(bool) { config.sources["holepunch"] = string(SourceCLI) } + if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) { + config.sources["doNotCreateNewClient"] = string(SourceCLI) + } return *version, *showConfig, nil } @@ -447,6 +459,10 @@ func mergeConfigs(dest, src *OlmConfig) { dest.Holepunch = src.Holepunch dest.sources["holepunch"] = string(SourceFile) } + if src.DoNotCreateNewClient { + dest.DoNotCreateNewClient = src.DoNotCreateNewClient + dest.sources["doNotCreateNewClient"] = string(SourceFile) + } } // SaveConfig saves the current configuration to the config file @@ -529,9 +545,10 @@ func (c *OlmConfig) ShowConfig() { // Advanced fmt.Println("\nAdvanced:") - fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch")) + fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch")) + fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient")) if c.TlsClientCert != "" { - fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) + fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) } // Source legend diff --git a/main.go b/main.go index 5b1b60f..80d81df 100644 --- a/main.go +++ b/main.go @@ -209,6 +209,7 @@ func main() { PingTimeoutDuration: config.PingTimeoutDuration, Version: config.Version, OrgID: config.OrgID, + DoNotCreateNewClient: config.DoNotCreateNewClient, } // Create a context that will be cancelled on interrupt signals diff --git a/olm/olm.go b/olm/olm.go index b5f0e51..895acd9 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -50,8 +50,9 @@ type Config struct { // Source tracking (not in JSON) sources map[string]string - Version string - OrgID string + Version string + OrgID string + DoNotCreateNewClient bool } var ( @@ -709,10 +710,11 @@ func TunnelProcess(ctx context.Context, config Config, id string, secret string, if stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": !config.Holepunch, - "olmVersion": config.Version, - "orgId": config.OrgID, + "publicKey": publicKey.String(), + "relay": !config.Holepunch, + "olmVersion": config.Version, + "orgId": config.OrgID, + "doNotCreateNewClient": config.DoNotCreateNewClient, }, 1*time.Second) }