diff --git a/config.go b/config.go index 6a87d94..4b6510a 100644 --- a/config.go +++ b/config.go @@ -43,6 +43,7 @@ type OlmConfig struct { Holepunch bool `json:"holepunch"` TlsClientCert string `json:"tlsClientCert"` OverrideDNS bool `json:"overrideDNS"` + DisableRelay bool `json:"disableRelay"` // DoNotCreateNewClient bool `json:"doNotCreateNewClient"` // Parsed values (not in JSON) @@ -104,6 +105,7 @@ func DefaultConfig() *OlmConfig { config.sources["pingTimeout"] = string(SourceDefault) config.sources["holepunch"] = string(SourceDefault) config.sources["overrideDNS"] = string(SourceDefault) + config.sources["disableRelay"] = string(SourceDefault) // config.sources["doNotCreateNewClient"] = string(SourceDefault) return config @@ -259,6 +261,10 @@ func loadConfigFromEnv(config *OlmConfig) { config.OverrideDNS = true config.sources["overrideDNS"] = string(SourceEnv) } + if val := os.Getenv("DISABLE_RELAY"); val == "true" { + config.DisableRelay = true + config.sources["disableRelay"] = string(SourceEnv) + } // if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" { // config.DoNotCreateNewClient = true // config.sources["doNotCreateNewClient"] = string(SourceEnv) @@ -288,6 +294,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { "enableApi": config.EnableAPI, "holepunch": config.Holepunch, "overrideDNS": config.OverrideDNS, + "disableRelay": config.DisableRelay, // "doNotCreateNewClient": config.DoNotCreateNewClient, } @@ -310,6 +317,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings") + serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections") // serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") version := serviceFlags.Bool("version", false, "Print the version") @@ -382,6 +390,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) { if config.OverrideDNS != origValues["overrideDNS"].(bool) { config.sources["overrideDNS"] = string(SourceCLI) } + if config.DisableRelay != origValues["disableRelay"].(bool) { + config.sources["disableRelay"] = string(SourceCLI) + } // if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) { // config.sources["doNotCreateNewClient"] = string(SourceCLI) // } @@ -502,6 +513,10 @@ func mergeConfigs(dest, src *OlmConfig) { dest.OverrideDNS = src.OverrideDNS dest.sources["overrideDNS"] = string(SourceFile) } + if src.DisableRelay { + dest.DisableRelay = src.DisableRelay + dest.sources["disableRelay"] = string(SourceFile) + } // if src.DoNotCreateNewClient { // dest.DoNotCreateNewClient = src.DoNotCreateNewClient // dest.sources["doNotCreateNewClient"] = string(SourceFile) @@ -591,6 +606,7 @@ func (c *OlmConfig) ShowConfig() { fmt.Println("\nAdvanced:") fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch")) fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS")) + fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay")) // fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient")) if c.TlsClientCert != "" { fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) diff --git a/main.go b/main.go index 630e7a1..572886f 100644 --- a/main.go +++ b/main.go @@ -234,8 +234,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { PingTimeoutDuration: config.PingTimeoutDuration, OrgID: config.OrgID, OverrideDNS: config.OverrideDNS, + DisableRelay: config.DisableRelay, EnableUAPI: true, - DisableRelay: false, // allow it to relay } go olm.StartTunnel(tunnelConfig) } else { diff --git a/olm/olm.go b/olm/olm.go index 0c8a50c..ddc4e88 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -45,6 +45,7 @@ var ( holePunchManager *holepunch.Manager peerMonitor *peermonitor.PeerMonitor globalConfig GlobalConfig + tunnelConfig TunnelConfig globalCtx context.Context stopRegister func() stopPeerSend func() @@ -99,7 +100,7 @@ func Init(ctx context.Context, config GlobalConfig) { globalConfig = config globalCtx = ctx - // Create a cancellable context for internal shutdown control + // Create a cancellable context for internal shutdown controconfiguration GlobalConfigl ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -209,6 +210,7 @@ func StartTunnel(config TunnelConfig) { } tunnelRunning = true // Also set it here in case it is called externally + tunnelConfig = config // Reset terminated status when tunnel starts apiServer.SetTerminated(false) @@ -245,7 +247,8 @@ func StartTunnel(config TunnelConfig) { id, // Use provided ID secret, // Use provided secret userToken, // Use provided user token OPTIONAL - endpoint, // Use provided endpoint + config.OrgID, + endpoint, // Use provided endpoint config.PingIntervalDuration, config.PingTimeoutDuration, ) @@ -1000,38 +1003,18 @@ func GetStatus() api.StatusResponse { func SwitchOrg(orgID string) error { logger.Info("Processing org switch request to orgId: %s", orgID) - - // Ensure we have an active olmClient - if olmClient == nil { - return fmt.Errorf("no active connection to switch organizations") + // stop the tunnel + if err := StopTunnel(); err != nil { + return fmt.Errorf("failed to stop existing tunnel: %w", err) } - // Update the orgID in the API server + // Update the org ID in the API server and global config apiServer.SetOrgID(orgID) - // Mark as not connected to trigger re-registration - connected = false + tunnelConfig.OrgID = orgID - // Close existing tunnel resources (but keep websocket alive) - Close() - - // Recreate sharedBind and holepunch manager - needed because Close() releases them - if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil { - return err - } - - // Clear peer statuses in API - apiServer.SetRegistered(false) - - // Trigger re-registration with new orgId - logger.Info("Re-registering with new orgId: %s", orgID) - publicKey := privateKey.PublicKey() - stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "relay": true, // Default to relay mode for org switch - "olmVersion": globalConfig.Version, - "orgId": orgID, - }, 1*time.Second) + // Restart the tunnel with the same config but new org ID + go StartTunnel(tunnelConfig) return nil } diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go index 4233238..dcdd1d9 100644 --- a/peermonitor/peermonitor.go +++ b/peermonitor/peermonitor.go @@ -73,7 +73,7 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w callback: callback, interval: 1 * time.Second, // Default check interval timeout: 2500 * time.Millisecond, - maxAttempts: 8, + maxAttempts: 15, privateKey: privateKey, wsClient: wsClient, device: device, diff --git a/websocket/client.go b/websocket/client.go index 74970a3..54b659a 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -62,6 +62,7 @@ type Config struct { Endpoint string TlsClientCert string // legacy PKCS12 file path UserToken string // optional user token for websocket authentication + OrgID string // optional organization ID for websocket authentication } type Client struct { @@ -131,12 +132,13 @@ func (c *Client) OnAuthError(callback func(statusCode int, message string)) { } // NewClient creates a new websocket client -func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { +func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { config := &Config{ ID: ID, Secret: secret, Endpoint: endpoint, UserToken: userToken, + OrgID: orgId, } client := &Client{ @@ -321,11 +323,10 @@ func (c *Client) getToken() (string, []ExitNode, error) { logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") } - var tokenData map[string]interface{} - - tokenData = map[string]interface{}{ + tokenData := map[string]interface{}{ "olmId": c.config.ID, "secret": c.config.Secret, + "orgId": c.config.OrgID, } jsonData, err := json.Marshal(tokenData)