diff --git a/olm/data.go b/olm/data.go index eff46f4..1cd29fa 100644 --- a/olm/data.go +++ b/olm/data.go @@ -135,67 +135,6 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) { logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId) } -func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { - logger.Debug("Received peer-handshake message: %v", msg.Data) - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Error("Error marshaling handshake data: %v", err) - return - } - - var handshakeData struct { - SiteId int `json:"siteId"` - ExitNode struct { - PublicKey string `json:"publicKey"` - Endpoint string `json:"endpoint"` - RelayPort uint16 `json:"relayPort"` - } `json:"exitNode"` - } - - if err := json.Unmarshal(jsonData, &handshakeData); err != nil { - logger.Error("Error unmarshaling handshake data: %v", err) - return - } - - // Get existing peer from PeerManager - _, exists := o.peerManager.GetPeer(handshakeData.SiteId) - if exists { - logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) - return - } - - relayPort := handshakeData.ExitNode.RelayPort - if relayPort == 0 { - relayPort = 21820 // default relay port - } - - siteId := handshakeData.SiteId - exitNode := holepunch.ExitNode{ - Endpoint: handshakeData.ExitNode.Endpoint, - RelayPort: relayPort, - PublicKey: handshakeData.ExitNode.PublicKey, - SiteIds: []int{siteId}, - } - - added := o.holePunchManager.AddExitNode(exitNode) - if added { - logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) - } else { - logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) - } - - o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt - o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud - - // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second, 10) - - logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) -} - // Handler for syncing peer configuration - reconciles expected state with actual state func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data) @@ -222,6 +161,9 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { return } + // Sync exit nodes for hole punching + o.syncExitNodes(syncData.ExitNodes) + // Build a map of expected peers from the incoming data expectedPeers := make(map[int]peers.SiteConfig) for _, site := range syncData.Sites { @@ -259,15 +201,21 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { // New peer - add it using the add flow (with holepunch) logger.Info("Sync: Adding new peer for site %d", siteId) - // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it o.holePunchManager.TriggerHolePunch() - // TODO: do we need to send the message to the cloud to add the peer that way? - if err := o.peerManager.AddPeer(expectedSite); err != nil { - logger.Error("Sync: Failed to add peer %d: %v", siteId, err) - } else { - logger.Info("Sync: Successfully added peer for site %d", siteId) - } + // // TODO: do we need to send the message to the cloud to add the peer that way? + // if err := o.peerManager.AddPeer(expectedSite); err != nil { + // logger.Error("Sync: Failed to add peer %d: %v", siteId, err) + // } else { + // logger.Info("Sync: Successfully added peer for site %d", siteId) + // } + + // add the peer via the server + // this is important because newt needs to get triggered as well to add the peer once the hp is complete + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + }, 1*time.Second, 10) + } else { // Existing peer - check if update is needed currentSite := currentPeerMap[siteId] @@ -342,3 +290,58 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers)) } + +// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager +func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) { + if o.holePunchManager == nil { + logger.Warn("Hole punch manager not initialized, skipping exit node sync") + return + } + + // Build a map of expected exit nodes by endpoint + expectedExitNodeMap := make(map[string]SyncExitNode) + for _, exitNode := range expectedExitNodes { + expectedExitNodeMap[exitNode.Endpoint] = exitNode + } + + // Get current exit nodes from hole punch manager + currentExitNodes := o.holePunchManager.GetExitNodes() + currentExitNodeMap := make(map[string]holepunch.ExitNode) + for _, exitNode := range currentExitNodes { + currentExitNodeMap[exitNode.Endpoint] = exitNode + } + + // Find exit nodes to remove (in current but not in expected) + for endpoint := range currentExitNodeMap { + if _, exists := expectedExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint) + o.holePunchManager.RemoveExitNode(endpoint) + } + } + + // Find exit nodes to add (in expected but not in current) + for endpoint, expectedExitNode := range expectedExitNodeMap { + if _, exists := currentExitNodeMap[endpoint]; !exists { + logger.Info("Sync: Adding new exit node %s", endpoint) + + relayPort := expectedExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + hpExitNode := holepunch.ExitNode{ + Endpoint: expectedExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: expectedExitNode.PublicKey, + SiteIds: expectedExitNode.SiteIds, + } + + if o.holePunchManager.AddExitNode(hpExitNode) { + logger.Info("Sync: Successfully added exit node %s", endpoint) + } + o.holePunchManager.TriggerHolePunch() + } + } + + logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap)) +} diff --git a/olm/peer.go b/olm/peer.go index 9bc842e..56e298d 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -2,7 +2,9 @@ package olm import ( "encoding/json" + "time" + "github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/util" "github.com/fosrl/olm/peers" @@ -193,3 +195,64 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) { o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay) } + +func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { + logger.Debug("Received peer-handshake message: %v", msg.Data) + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Error("Error marshaling handshake data: %v", err) + return + } + + var handshakeData struct { + SiteId int `json:"siteId"` + ExitNode struct { + PublicKey string `json:"publicKey"` + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + } `json:"exitNode"` + } + + if err := json.Unmarshal(jsonData, &handshakeData); err != nil { + logger.Error("Error unmarshaling handshake data: %v", err) + return + } + + // Get existing peer from PeerManager + _, exists := o.peerManager.GetPeer(handshakeData.SiteId) + if exists { + logger.Warn("Peer with site ID %d already added", handshakeData.SiteId) + return + } + + relayPort := handshakeData.ExitNode.RelayPort + if relayPort == 0 { + relayPort = 21820 // default relay port + } + + siteId := handshakeData.SiteId + exitNode := holepunch.ExitNode{ + Endpoint: handshakeData.ExitNode.Endpoint, + RelayPort: relayPort, + PublicKey: handshakeData.ExitNode.PublicKey, + SiteIds: []int{siteId}, + } + + added := o.holePunchManager.AddExitNode(exitNode) + if added { + logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint) + } else { + logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint) + } + + o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt + o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud + + // Send handshake acknowledgment back to server with retry + o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + }, 1*time.Second, 10) + + logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) +} diff --git a/olm/types.go b/olm/types.go index 491ed19..2e56ad7 100644 --- a/olm/types.go +++ b/olm/types.go @@ -13,7 +13,15 @@ type WgData struct { } type SyncData struct { - Sites []peers.SiteConfig `json:"sites"` + Sites []peers.SiteConfig `json:"sites"` + ExitNodes []SyncExitNode `json:"exitNodes"` +} + +type SyncExitNode struct { + Endpoint string `json:"endpoint"` + RelayPort uint16 `json:"relayPort"` + PublicKey string `json:"publicKey"` + SiteIds []int `json:"siteIds"` } type OlmConfig struct { diff --git a/websocket/client.go b/websocket/client.go index 8bcbeb3..4a1099e 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -96,6 +96,9 @@ type Client struct { exitNodes []ExitNode // Cached exit nodes from token response tokenMux sync.RWMutex // Protects token and exitNodes forceNewToken bool // Flag to force fetching a new token on next connection + processingMessage bool // Flag to track if a message is currently being processed + processingMux sync.RWMutex // Protects processingMessage + processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete } type ClientOption func(*Client) @@ -222,6 +225,9 @@ func (c *Client) Disconnect() error { c.isDisconnected = true c.setConnected(false) + // Wait for any message currently being processed to complete + c.processingWg.Wait() + if c.conn != nil { c.writeMux.Lock() c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) @@ -651,6 +657,14 @@ func (c *Client) pingMonitor() { if c.isDisconnected || c.conn == nil { return } + // Skip ping if a message is currently being processed + c.processingMux.RLock() + isProcessing := c.processingMessage + c.processingMux.RUnlock() + if isProcessing { + logger.Debug("websocket: Skipping ping, message is being processed") + continue + } // Send application-level ping with config version c.configVersionMux.RLock() configVersion := c.configVersion @@ -753,7 +767,19 @@ func (c *Client) readPumpWithDisconnectDetection() { c.handlersMux.RLock() if handler, ok := c.handlers[msg.Type]; ok { + // Mark that we're processing a message + c.processingMux.Lock() + c.processingMessage = true + c.processingMux.Unlock() + c.processingWg.Add(1) + handler(msg) + + // Mark that we're done processing + c.processingWg.Done() + c.processingMux.Lock() + c.processingMessage = false + c.processingMux.Unlock() } c.handlersMux.RUnlock() }