Make relay optional

Former-commit-id: e9e4b00994
This commit is contained in:
Owen
2025-12-01 16:19:23 -05:00
parent 7270b840cf
commit 6e4ec246ef
5 changed files with 35 additions and 35 deletions

View File

@@ -43,6 +43,7 @@ type OlmConfig struct {
Holepunch bool `json:"holepunch"` Holepunch bool `json:"holepunch"`
TlsClientCert string `json:"tlsClientCert"` TlsClientCert string `json:"tlsClientCert"`
OverrideDNS bool `json:"overrideDNS"` OverrideDNS bool `json:"overrideDNS"`
DisableRelay bool `json:"disableRelay"`
// DoNotCreateNewClient bool `json:"doNotCreateNewClient"` // DoNotCreateNewClient bool `json:"doNotCreateNewClient"`
// Parsed values (not in JSON) // Parsed values (not in JSON)
@@ -104,6 +105,7 @@ func DefaultConfig() *OlmConfig {
config.sources["pingTimeout"] = string(SourceDefault) config.sources["pingTimeout"] = string(SourceDefault)
config.sources["holepunch"] = string(SourceDefault) config.sources["holepunch"] = string(SourceDefault)
config.sources["overrideDNS"] = string(SourceDefault) config.sources["overrideDNS"] = string(SourceDefault)
config.sources["disableRelay"] = string(SourceDefault)
// config.sources["doNotCreateNewClient"] = string(SourceDefault) // config.sources["doNotCreateNewClient"] = string(SourceDefault)
return config return config
@@ -259,6 +261,10 @@ func loadConfigFromEnv(config *OlmConfig) {
config.OverrideDNS = true config.OverrideDNS = true
config.sources["overrideDNS"] = string(SourceEnv) config.sources["overrideDNS"] = string(SourceEnv)
} }
if val := os.Getenv("DISABLE_RELAY"); val == "true" {
config.DisableRelay = true
config.sources["disableRelay"] = string(SourceEnv)
}
// if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" { // if val := os.Getenv("DO_NOT_CREATE_NEW_CLIENT"); val == "true" {
// config.DoNotCreateNewClient = true // config.DoNotCreateNewClient = true
// config.sources["doNotCreateNewClient"] = string(SourceEnv) // config.sources["doNotCreateNewClient"] = string(SourceEnv)
@@ -288,6 +294,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
"enableApi": config.EnableAPI, "enableApi": config.EnableAPI,
"holepunch": config.Holepunch, "holepunch": config.Holepunch,
"overrideDNS": config.OverrideDNS, "overrideDNS": config.OverrideDNS,
"disableRelay": config.DisableRelay,
// "doNotCreateNewClient": config.DoNotCreateNewClient, // "doNotCreateNewClient": config.DoNotCreateNewClient,
} }
@@ -310,6 +317,7 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests") serviceFlags.BoolVar(&config.EnableAPI, "enable-api", config.EnableAPI, "Enable API server for receiving connection requests")
serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching") serviceFlags.BoolVar(&config.Holepunch, "holepunch", config.Holepunch, "Enable hole punching")
serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings") serviceFlags.BoolVar(&config.OverrideDNS, "override-dns", config.OverrideDNS, "Override system DNS settings")
serviceFlags.BoolVar(&config.DisableRelay, "disable-relay", config.DisableRelay, "Disable relay connections")
// serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client") // serviceFlags.BoolVar(&config.DoNotCreateNewClient, "do-not-create-new-client", config.DoNotCreateNewClient, "Do not create new client")
version := serviceFlags.Bool("version", false, "Print the version") version := serviceFlags.Bool("version", false, "Print the version")
@@ -382,6 +390,9 @@ func loadConfigFromCLI(config *OlmConfig, args []string) (bool, bool, error) {
if config.OverrideDNS != origValues["overrideDNS"].(bool) { if config.OverrideDNS != origValues["overrideDNS"].(bool) {
config.sources["overrideDNS"] = string(SourceCLI) config.sources["overrideDNS"] = string(SourceCLI)
} }
if config.DisableRelay != origValues["disableRelay"].(bool) {
config.sources["disableRelay"] = string(SourceCLI)
}
// if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) { // if config.DoNotCreateNewClient != origValues["doNotCreateNewClient"].(bool) {
// config.sources["doNotCreateNewClient"] = string(SourceCLI) // config.sources["doNotCreateNewClient"] = string(SourceCLI)
// } // }
@@ -502,6 +513,10 @@ func mergeConfigs(dest, src *OlmConfig) {
dest.OverrideDNS = src.OverrideDNS dest.OverrideDNS = src.OverrideDNS
dest.sources["overrideDNS"] = string(SourceFile) dest.sources["overrideDNS"] = string(SourceFile)
} }
if src.DisableRelay {
dest.DisableRelay = src.DisableRelay
dest.sources["disableRelay"] = string(SourceFile)
}
// if src.DoNotCreateNewClient { // if src.DoNotCreateNewClient {
// dest.DoNotCreateNewClient = src.DoNotCreateNewClient // dest.DoNotCreateNewClient = src.DoNotCreateNewClient
// dest.sources["doNotCreateNewClient"] = string(SourceFile) // dest.sources["doNotCreateNewClient"] = string(SourceFile)
@@ -591,6 +606,7 @@ func (c *OlmConfig) ShowConfig() {
fmt.Println("\nAdvanced:") fmt.Println("\nAdvanced:")
fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch")) fmt.Printf(" holepunch = %v [%s]\n", c.Holepunch, getSource("holepunch"))
fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS")) fmt.Printf(" override-dns = %v [%s]\n", c.OverrideDNS, getSource("overrideDNS"))
fmt.Printf(" disable-relay = %v [%s]\n", c.DisableRelay, getSource("disableRelay"))
// fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient")) // fmt.Printf(" do-not-create-new-client = %v [%s]\n", c.DoNotCreateNewClient, getSource("doNotCreateNewClient"))
if c.TlsClientCert != "" { if c.TlsClientCert != "" {
fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert")) fmt.Printf(" tls-cert = %s [%s]\n", c.TlsClientCert, getSource("tlsClientCert"))

View File

@@ -234,8 +234,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) {
PingTimeoutDuration: config.PingTimeoutDuration, PingTimeoutDuration: config.PingTimeoutDuration,
OrgID: config.OrgID, OrgID: config.OrgID,
OverrideDNS: config.OverrideDNS, OverrideDNS: config.OverrideDNS,
DisableRelay: config.DisableRelay,
EnableUAPI: true, EnableUAPI: true,
DisableRelay: false, // allow it to relay
} }
go olm.StartTunnel(tunnelConfig) go olm.StartTunnel(tunnelConfig)
} else { } else {

View File

@@ -45,6 +45,7 @@ var (
holePunchManager *holepunch.Manager holePunchManager *holepunch.Manager
peerMonitor *peermonitor.PeerMonitor peerMonitor *peermonitor.PeerMonitor
globalConfig GlobalConfig globalConfig GlobalConfig
tunnelConfig TunnelConfig
globalCtx context.Context globalCtx context.Context
stopRegister func() stopRegister func()
stopPeerSend func() stopPeerSend func()
@@ -99,7 +100,7 @@ func Init(ctx context.Context, config GlobalConfig) {
globalConfig = config globalConfig = config
globalCtx = ctx globalCtx = ctx
// Create a cancellable context for internal shutdown control // Create a cancellable context for internal shutdown controconfiguration GlobalConfigl
ctx, cancel := context.WithCancel(ctx) ctx, cancel := context.WithCancel(ctx)
defer cancel() defer cancel()
@@ -209,6 +210,7 @@ func StartTunnel(config TunnelConfig) {
} }
tunnelRunning = true // Also set it here in case it is called externally tunnelRunning = true // Also set it here in case it is called externally
tunnelConfig = config
// Reset terminated status when tunnel starts // Reset terminated status when tunnel starts
apiServer.SetTerminated(false) apiServer.SetTerminated(false)
@@ -245,7 +247,8 @@ func StartTunnel(config TunnelConfig) {
id, // Use provided ID id, // Use provided ID
secret, // Use provided secret secret, // Use provided secret
userToken, // Use provided user token OPTIONAL userToken, // Use provided user token OPTIONAL
endpoint, // Use provided endpoint config.OrgID,
endpoint, // Use provided endpoint
config.PingIntervalDuration, config.PingIntervalDuration,
config.PingTimeoutDuration, config.PingTimeoutDuration,
) )
@@ -1000,38 +1003,18 @@ func GetStatus() api.StatusResponse {
func SwitchOrg(orgID string) error { func SwitchOrg(orgID string) error {
logger.Info("Processing org switch request to orgId: %s", orgID) logger.Info("Processing org switch request to orgId: %s", orgID)
// stop the tunnel
// Ensure we have an active olmClient if err := StopTunnel(); err != nil {
if olmClient == nil { return fmt.Errorf("failed to stop existing tunnel: %w", err)
return fmt.Errorf("no active connection to switch organizations")
} }
// Update the orgID in the API server // Update the org ID in the API server and global config
apiServer.SetOrgID(orgID) apiServer.SetOrgID(orgID)
// Mark as not connected to trigger re-registration tunnelConfig.OrgID = orgID
connected = false
// Close existing tunnel resources (but keep websocket alive) // Restart the tunnel with the same config but new org ID
Close() go StartTunnel(tunnelConfig)
// Recreate sharedBind and holepunch manager - needed because Close() releases them
if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil {
return err
}
// Clear peer statuses in API
apiServer.SetRegistered(false)
// Trigger re-registration with new orgId
logger.Info("Re-registering with new orgId: %s", orgID)
publicKey := privateKey.PublicKey()
stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"relay": true, // Default to relay mode for org switch
"olmVersion": globalConfig.Version,
"orgId": orgID,
}, 1*time.Second)
return nil return nil
} }

View File

@@ -73,7 +73,7 @@ func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *w
callback: callback, callback: callback,
interval: 1 * time.Second, // Default check interval interval: 1 * time.Second, // Default check interval
timeout: 2500 * time.Millisecond, timeout: 2500 * time.Millisecond,
maxAttempts: 8, maxAttempts: 15,
privateKey: privateKey, privateKey: privateKey,
wsClient: wsClient, wsClient: wsClient,
device: device, device: device,

View File

@@ -62,6 +62,7 @@ type Config struct {
Endpoint string Endpoint string
TlsClientCert string // legacy PKCS12 file path TlsClientCert string // legacy PKCS12 file path
UserToken string // optional user token for websocket authentication UserToken string // optional user token for websocket authentication
OrgID string // optional organization ID for websocket authentication
} }
type Client struct { type Client struct {
@@ -131,12 +132,13 @@ func (c *Client) OnAuthError(callback func(statusCode int, message string)) {
} }
// NewClient creates a new websocket client // NewClient creates a new websocket client
func NewClient(ID, secret string, userToken string, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) { func NewClient(ID, secret, userToken, orgId, endpoint string, pingInterval time.Duration, pingTimeout time.Duration, opts ...ClientOption) (*Client, error) {
config := &Config{ config := &Config{
ID: ID, ID: ID,
Secret: secret, Secret: secret,
Endpoint: endpoint, Endpoint: endpoint,
UserToken: userToken, UserToken: userToken,
OrgID: orgId,
} }
client := &Client{ client := &Client{
@@ -321,11 +323,10 @@ func (c *Client) getToken() (string, []ExitNode, error) {
logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable") logger.Debug("TLS certificate verification disabled via SKIP_TLS_VERIFY environment variable")
} }
var tokenData map[string]interface{} tokenData := map[string]interface{}{
tokenData = map[string]interface{}{
"olmId": c.config.ID, "olmId": c.config.ID,
"secret": c.config.Secret, "secret": c.config.Secret,
"orgId": c.config.OrgID,
} }
jsonData, err := json.Marshal(tokenData) jsonData, err := json.Marshal(tokenData)