diff --git a/olm/data.go b/olm/data.go index 015931b..d0e6d5b 100644 --- a/olm/data.go +++ b/olm/data.go @@ -2,6 +2,7 @@ package olm import ( "encoding/json" + "fmt" "time" "github.com/fosrl/newt/holepunch" @@ -231,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) { // add the peer via the server // this is important because newt needs to get triggered as well to add the peer once the hp is complete - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": expectedSite.SiteId, - }, 1*time.Second, 10) + chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId) + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[chainId]; ok { + stop() + } + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": expectedSite.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() } else { // Existing peer - check if update is needed diff --git a/olm/olm.go b/olm/olm.go index fa32ebd..b2843d2 100644 --- a/olm/olm.go +++ b/olm/olm.go @@ -2,6 +2,8 @@ package olm import ( "context" + "crypto/rand" + "encoding/hex" "fmt" "net" "net/http" @@ -65,8 +67,9 @@ type Olm struct { stopRegister func() updateRegister func(newData any) - stopPeerSend func() - stopPeerInit func() + stopPeerSends map[string]func() + stopPeerInits map[string]func() + peerSendMu sync.Mutex // WaitGroup to track tunnel lifecycle tunnelWg sync.WaitGroup @@ -117,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error { return nil } +// generateChainId generates a random chain ID for tracking peer sender lifecycles. 
+func generateChainId() string { + b := make([]byte, 8) + _, _ = rand.Read(b) + return hex.EncodeToString(b) +} + func Init(ctx context.Context, config OlmConfig) (*Olm, error) { logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) @@ -167,10 +177,12 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) { apiServer.SetAgent(config.Agent) newOlm := &Olm{ - logFile: logFile, - olmCtx: ctx, - apiServer: apiServer, - olmConfig: config, + logFile: logFile, + olmCtx: ctx, + apiServer: apiServer, + olmConfig: config, + stopPeerSends: make(map[string]func()), + stopPeerInits: make(map[string]func()), } newOlm.registerAPICallbacks() @@ -287,12 +299,17 @@ func (o *Olm) registerAPICallbacks() { }, func(req api.JITConnectionRequest) error { logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) - - o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ - "siteId": req.Site, + + chainId := generateChainId() + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ + "siteId": req.Site, "resourceId": req.Resource, - }, 2*time.Second, 10) - + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerInits[chainId] = stopFunc + o.peerSendMu.Unlock() + return nil }, ) @@ -389,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { // Handler for peer handshake - adds exit node to holepunch rotation and notifies server o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) + o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain) o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.OnConnect(func() error { @@ -431,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) { "userToken": userToken, "fingerprint": o.fingerprint, "postures": o.postures, - }, 1*time.Second, 10) + }, 2*time.Second, 10) // Invoke onRegistered callback if 
configured if o.olmConfig.OnRegistered != nil { @@ -528,6 +546,22 @@ func (o *Olm) Close() { o.stopRegister = nil } + // Stop all pending peer init and send senders before closing websocket + o.peerSendMu.Lock() + for _, stop := range o.stopPeerInits { + if stop != nil { + stop() + } + } + o.stopPeerInits = make(map[string]func()) + for _, stop := range o.stopPeerSends { + if stop != nil { + stop() + } + } + o.stopPeerSends = make(map[string]func()) + o.peerSendMu.Unlock() + // send a disconnect message to the cloud to show disconnected if o.websocket != nil { o.websocket.SendMessage("olm/disconnecting", map[string]any{}) diff --git a/olm/peer.go b/olm/peer.go index 8007272..1937934 100644 --- a/olm/peer.go +++ b/olm/peer.go @@ -20,31 +20,39 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) { return } - if o.stopPeerSend != nil { - o.stopPeerSend() - o.stopPeerSend = nil - } - jsonData, err := json.Marshal(msg.Data) if err != nil { logger.Error("Error marshaling data: %v", err) return } - var siteConfig peers.SiteConfig - if err := json.Unmarshal(jsonData, &siteConfig); err != nil { + var siteConfigMsg struct { + peers.SiteConfig + ChainId string `json:"chainId"` + } + if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil { logger.Error("Error unmarshaling add data: %v", err) return } + if siteConfigMsg.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok { + stop() + delete(o.stopPeerSends, siteConfigMsg.ChainId) + } + o.peerSendMu.Unlock() + } + _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it - if err := o.peerManager.AddPeer(siteConfig); err != nil { + if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil { logger.Error("Failed to add peer: %v", err) return } - logger.Info("Successfully added peer for site %d", siteConfig.SiteId) + + logger.Info("Successfully added 
peer for site %d", siteConfigMsg.SiteId) } func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { @@ -230,7 +238,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { } var handshakeData struct { - SiteId int `json:"siteId"` + SiteId int `json:"siteId"` + ChainId string `json:"chainId"` ExitNode struct { PublicKey string `json:"publicKey"` Endpoint string `json:"endpoint"` @@ -242,6 +251,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { logger.Error("Error unmarshaling handshake data: %v", err) return } + + // Stop the peer init sender for this chain, if any + if handshakeData.ChainId != "" { + o.peerSendMu.Lock() + if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok { + stop() + delete(o.stopPeerInits, handshakeData.ChainId) + } + o.peerSendMu.Unlock() + } // Get existing peer from PeerManager _, exists := o.peerManager.GetPeer(handshakeData.SiteId) @@ -273,10 +292,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) { o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud - // Send handshake acknowledgment back to server with retry - o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ - "siteId": handshakeData.SiteId, - }, 1*time.Second, 10) + // Send handshake acknowledgment back to server with retry, keyed by chainId + chainId := handshakeData.ChainId + if chainId == "" { + chainId = generateChainId() + } + o.peerSendMu.Lock() + stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ + "siteId": handshakeData.SiteId, + "chainId": chainId, + }, 2*time.Second, 10) + o.stopPeerSends[chainId] = stopFunc + o.peerSendMu.Unlock() logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) } + 
+// handleCancelChain handles "olm/wg/peer/chain/cancel": it stops and forgets
+// every pending peer init/add retry sender registered under the given chainId.
+func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
+	logger.Debug("Received cancel-chain message: %v", msg.Data)
+
+	jsonData, err := json.Marshal(msg.Data)
+	if err != nil {
+		logger.Error("Error marshaling cancel-chain data: %v", err)
+		return
+	}
+
+	var cancelData struct {
+		ChainId string `json:"chainId"`
+	}
+	if err := json.Unmarshal(jsonData, &cancelData); err != nil {
+		logger.Error("Error unmarshaling cancel-chain data: %v", err)
+		return
+	}
+	if cancelData.ChainId == "" {
+		logger.Warn("Received cancel-chain message with no chainId")
+		return
+	}
+
+	o.peerSendMu.Lock()
+	defer o.peerSendMu.Unlock()
+
+	found := false
+	for _, senders := range []map[string]func(){o.stopPeerInits, o.stopPeerSends} {
+		if stop, ok := senders[cancelData.ChainId]; ok {
+			// Stop funcs are stored with SendMessageInterval's error discarded,
+			// so an entry could be nil; guard like Close does to avoid a panic.
+			if stop != nil {
+				stop()
+			}
+			delete(senders, cancelData.ChainId)
+			found = true
+		}
+	}
+
+	if found {
+		logger.Info("Cancelled chain %s", cancelData.ChainId)
+	} else {
+		logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
+	}
+}