diff --git a/main.go b/main.go index 5e4e1d9..630e7a1 100644 --- a/main.go +++ b/main.go @@ -235,7 +235,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { OrgID: config.OrgID, OverrideDNS: config.OverrideDNS, EnableUAPI: true, - DisableRelay: true, + DisableRelay: false, // allow it to relay } go olm.StartTunnel(tunnelConfig) } else { diff --git a/olm/olm.go b/olm/olm.go index b1ffb12..0c8a50c 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -33,7 +33,6 @@ var ( connected bool dev *device.Device wgData WgData - holePunchData HolePunchData uapiListener net.Listener tdev tun.Device middleDev *olmDevice.MiddleDevice @@ -48,13 +47,22 @@ var ( globalConfig GlobalConfig globalCtx context.Context stopRegister func() + stopPeerSend func() + updateRegister func(newData interface{}) stopPing chan struct{} peerManager *peers.PeerManager ) -// initSharedBindAndHolepunch creates the shared UDP socket and holepunch manager. +// initTunnelInfo creates the shared UDP socket and holepunch manager. // This is used during initial tunnel setup and when switching organizations. -func initSharedBindAndHolepunch(clientID string) error { +func initTunnelInfo(clientID string) error { + var err error + privateKey, err = wgtypes.GeneratePrivateKey() + if err != nil { + logger.Error("Failed to generate private key: %v", err) + return err + } + sourcePort, err := util.FindAvailableUDPPort(49152, 65535) if err != nil { return fmt.Errorf("failed to find available UDP port: %w", err) @@ -82,7 +90,7 @@ func initSharedBindAndHolepunch(clientID string) error { logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount()) // Create the holepunch manager - holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm") + holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String()) return nil } @@ -249,82 +257,12 @@ func StartTunnel(config TunnelConfig) { // Store the client reference globally olmClient = olm - privateKey, err = wgtypes.GeneratePrivateKey() - if err != nil { - logger.Error("Failed to generate private key: %v", err) - return - } - // Create shared UDP socket and holepunch manager - if err := initSharedBindAndHolepunch(id); err != nil { + if err := initTunnelInfo(id); err != nil { logger.Error("%v", err) return } - olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { - logger.Debug("Received message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &holePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Convert HolePunchData.ExitNodes to holepunch.ExitNode slice - exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes)) - for i, node := range holePunchData.ExitNodes { - exitNodes[i] = holepunch.ExitNode{ - Endpoint: node.Endpoint, - PublicKey: node.PublicKey, - } - } - - // Start hole punching using the manager - logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) - if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil { - logger.Warn("Failed to start hole punch: %v", err) - } - }) - - olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - // THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY - logger.Debug("Received message: %v", msg.Data) - - type LegacyHolePunchData struct { - ServerPubKey string `json:"serverPubKey"` - Endpoint string `json:"endpoint"` - } - - var legacyHolePunchData LegacyHolePunchData - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Stop any existing hole punch operations - if holePunchManager != nil { - holePunchManager.Stop() - } - - // Start hole punching for the exit node - logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) - if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil { - logger.Warn("Failed to start hole punch: %v", err) - } - }) - olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -338,9 +276,9 @@ func StartTunnel(config TunnelConfig) { stopRegister = nil } - // wait 10 milliseconds to ensure the previous connection is closed - logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed") - time.Sleep(500 * time.Millisecond) + if updateRegister != nil { + updateRegister = nil + } // if there is an existing tunnel then close it if dev != nil { @@ -572,6 +510,11 @@ func StartTunnel(config TunnelConfig) { olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) { logger.Debug("Received add-peer message: %v", msg.Data) + if stopPeerSend != nil { + stopPeerSend() + stopPeerSend = nil + } + jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) @@ -584,6 +527,8 @@ func StartTunnel(config TunnelConfig) { return } + holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it + if err := peerManager.AddPeer(siteConfig, endpoint); err != nil { logger.Error("Failed to add peer: %v", err) return @@ -753,6 +698,59 @@ func StartTunnel(config TunnelConfig) { peerMonitor.HandleFailover(relayData.SiteId, primaryRelay) }) + // Handler for peer handshake - adds exit node to holepunch rotation and notifies server + olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Add exit node to holepunch rotation if we have a holepunch manager + if holePunchManager != nil { + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + PublicKey: handshakeData.ExitNode.PublicKey, + } + + added := holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + // Start holepunching if not already running + if !holePunchManager.IsRunning() { + if err := holePunchManager.Start(); err != nil { + logger.Error("Failed to start holepunch manager: %v", err) + } + } + } + + // Send handshake acknowledgment back to server with retry + stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) + }) + olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) { logger.Info("Received terminate message") apiServer.SetTerminated(true) @@ -779,15 +777,17 @@ func StartTunnel(config TunnelConfig) { publicKey := privateKey.PublicKey() + // delay for 500ms to allow for time for the hp to get processed + time.Sleep(500 * time.Millisecond) + if stopRegister == nil { logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch) - stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "relay": !config.Holepunch, "olmVersion": globalConfig.Version, "orgId": config.OrgID, "userToken": userToken, - // "doNotCreateNewClient": config.DoNotCreateNewClient, }, 1*time.Second) // Invoke onRegistered callback if configured @@ -801,9 +801,28 @@ func StartTunnel(config TunnelConfig) { return nil }) - olm.OnTokenUpdate(func(token string) { + olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) { if holePunchManager != nil { holePunchManager.SetToken(token) + + logger.Debug("Got exit nodes for hole punching: %v", exitNodes) + + // Convert websocket.ExitNode to holepunch.ExitNode + hpExitNodes := make([]holepunch.ExitNode, len(exitNodes)) + for i, node := range exitNodes { + hpExitNodes[i] = holepunch.ExitNode{ + Endpoint: node.Endpoint, + PublicKey: node.PublicKey, + } + } + + logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes) + + // Start hole punching using the manager + logger.Info("Starting hole punch for %d exit nodes", len(exitNodes)) + if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil { + logger.Warn("Failed to start hole punch: %v", err) + } } }) @@ -814,6 +833,7 @@ func StartTunnel(config TunnelConfig) { apiServer.SetRegistered(false) apiServer.ClearPeerStatuses() network.ClearNetworkSettings() + Close() if globalConfig.OnAuthError != nil { @@ -864,6 +884,10 @@ func Close() { stopRegister = nil } + if updateRegister != nil { + updateRegister = nil + } + if peerMonitor != nil { peerMonitor.Close() // Close() also calls Stop() internally peerMonitor = nil @@ -992,7 +1016,7 @@ func SwitchOrg(orgID string) error { Close() // Recreate sharedBind and holepunch manager - needed because Close() releases them - if err := initSharedBindAndHolepunch(olmClient.GetConfig().ID); err != nil { + if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil { return err } @@ -1002,7 +1026,7 @@ func SwitchOrg(orgID string) error { // Trigger re-registration with new orgId logger.Info("Re-registering with new orgId: %s", orgID) publicKey := privateKey.PublicKey() - stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ + stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{ "publicKey": publicKey.String(), "relay": true, // Default to relay mode for org switch "olmVersion": globalConfig.Version, diff --git a/olm/types.go b/olm/types.go index 39fef25..5f384b7 100644 --- a/olm/types.go +++ b/olm/types.go @@ -12,25 +12,6 @@ type WgData struct { UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses } -type HolePunchMessage struct { - NewtID string `json:"newtId"` -} - -type ExitNode struct { - Endpoint string `json:"endpoint"` - PublicKey string `json:"publicKey"` -} - -type HolePunchData struct { - ExitNodes []ExitNode `json:"exitNodes"` -} - -type EncryptedHolePunchMessage struct { - EphemeralPublicKey string `json:"ephemeralPublicKey"` - Nonce []byte `json:"nonce"` - Ciphertext []byte `json:"ciphertext"` -} - type GlobalConfig struct { // Logging LogLevel string diff --git a/websocket/client.go b/websocket/client.go index 64ffb45..74970a3 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -38,12 +38,18 @@ func IsAuthError(err error) bool { type TokenResponse struct { Data struct { - Token string `json:"token"` + Token string `json:"token"` + ExitNodes []ExitNode `json:"exitNodes"` } `json:"data"` Success bool `json:"success"` Message string `json:"message"` } +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + type WSMessage struct { Type string `json:"type"` Data interface{} `json:"data"` @@ -71,7 +77,7 @@ type Client struct { pingInterval time.Duration pingTimeout time.Duration onConnect func() error - onTokenUpdate func(token string) + onTokenUpdate func(token string, exitNodes []ExitNode) onAuthError func(statusCode int, message string) // Callback for auth errors writeMux sync.Mutex clientType string // Type of client (e.g., "newt", "olm") @@ -116,7 +122,7 @@ func (c *Client) OnConnect(callback func() error) { c.onConnect = callback } -func (c *Client) OnTokenUpdate(callback func(token string)) { +func (c *Client) OnTokenUpdate(callback func(token string, exitNodes []ExitNode)) { c.onTokenUpdate = callback } @@ -212,13 +218,17 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } -func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func(), update func(newData interface{})) { stopChan := make(chan struct{}) + updateChan := make(chan interface{}) + var dataMux sync.Mutex + currentData := data + go func() { count := 0 maxAttempts := 10 - err := c.SendMessage(messageType, data) // Send immediately + err := c.SendMessage(messageType, currentData) // Send immediately if err != nil { logger.Error("Failed to send initial message: %v", err) } @@ -233,19 +243,46 @@ func (c *Client) SendMessageInterval(messageType string, data interface{}, inter logger.Info("SendMessageInterval timed out after %d attempts for message type: %s", maxAttempts, messageType) return } - err = c.SendMessage(messageType, data) + dataMux.Lock() + err = c.SendMessage(messageType, currentData) + dataMux.Unlock() if err != nil { logger.Error("Failed to send message: %v", err) } count++ + case newData := <-updateChan: + dataMux.Lock() + // Merge newData into currentData if both are maps + if currentMap, ok := currentData.(map[string]interface{}); ok { + if newMap, ok := newData.(map[string]interface{}); ok { + // Update or add keys from newData + for key, value := range newMap { + currentMap[key] = value + } + currentData = currentMap + } else { + // If newData is not a map, replace entirely + currentData = newData + } + } else { + // If currentData is not a map, replace entirely + currentData = newData + } + dataMux.Unlock() case <-stopChan: return } } }() return func() { - close(stopChan) - } + close(stopChan) + }, func(newData interface{}) { + select { + case updateChan <- newData: + case <-stopChan: + // Channel is closed, ignore update + } + } } // RegisterHandler registers a handler for a specific message type @@ -255,11 +292,11 @@ func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlers[messageType] = handler } -func (c *Client) getToken() (string, error) { +func (c *Client) getToken() (string, []ExitNode, error) { // Parse the base URL to ensure we have the correct hostname baseURL, err := url.Parse(c.baseURL) if err != nil { - return "", fmt.Errorf("failed to parse base URL: %w", err) + return "", nil, fmt.Errorf("failed to parse base URL: %w", err) } // Ensure we have the base URL without trailing slashes @@ -271,7 +308,7 @@ func (c *Client) getToken() (string, error) { if c.tlsConfig.ClientCertFile != "" || c.tlsConfig.ClientKeyFile != "" || len(c.tlsConfig.CAFiles) > 0 || c.tlsConfig.PKCS12File != "" { tlsConfig, err = c.setupTLS() if err != nil { - return "", fmt.Errorf("failed to setup TLS configuration: %w", err) + return "", nil, fmt.Errorf("failed to setup TLS configuration: %w", err) } } @@ -293,7 +330,7 @@ func (c *Client) getToken() (string, error) { jsonData, err := json.Marshal(tokenData) if err != nil { - return "", fmt.Errorf("failed to marshal token request data: %w", err) + return "", nil, fmt.Errorf("failed to marshal token request data: %w", err) } // Create a new request @@ -303,7 +340,7 @@ func (c *Client) getToken() (string, error) { bytes.NewBuffer(jsonData), ) if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) + return "", nil, fmt.Errorf("failed to create request: %w", err) } // Set headers @@ -319,7 +356,7 @@ func (c *Client) getToken() (string, error) { } resp, err := client.Do(req) if err != nil { - return "", fmt.Errorf("failed to request new token: %w", err) + return "", nil, fmt.Errorf("failed to request new token: %w", err) } defer resp.Body.Close() @@ -329,33 +366,33 @@ func (c *Client) getToken() (string, error) { // Return AuthError for 401/403 status codes if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { - return "", &AuthError{ + return "", nil, &AuthError{ StatusCode: resp.StatusCode, Message: string(body), } } // For other errors (5xx, network issues, etc.), return regular error - return "", fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) + return "", nil, fmt.Errorf("failed to get token with status code: %d, body: %s", resp.StatusCode, string(body)) } var tokenResp TokenResponse if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { logger.Error("Failed to decode token response.") - return "", fmt.Errorf("failed to decode token response: %w", err) + return "", nil, fmt.Errorf("failed to decode token response: %w", err) } if !tokenResp.Success { - return "", fmt.Errorf("failed to get token: %s", tokenResp.Message) + return "", nil, fmt.Errorf("failed to get token: %s", tokenResp.Message) } if tokenResp.Data.Token == "" { - return "", fmt.Errorf("received empty token from server") + return "", nil, fmt.Errorf("received empty token from server") } logger.Debug("Received token: %s", tokenResp.Data.Token) - return tokenResp.Data.Token, nil + return tokenResp.Data.Token, tokenResp.Data.ExitNodes, nil } func (c *Client) connectWithRetry() { @@ -389,13 +426,13 @@ func (c *Client) connectWithRetry() { func (c *Client) establishConnection() error { // Get token for authentication - token, err := c.getToken() + token, exitNodes, err := c.getToken() if err != nil { return fmt.Errorf("failed to get token: %w", err) } if c.onTokenUpdate != nil { - c.onTokenUpdate(token) + c.onTokenUpdate(token, exitNodes) } // Parse the base URL to determine protocol and hostname