Working jit with chain ids

This commit is contained in:
Owen
2026-03-04 17:51:48 -08:00
parent e7507e0837
commit 051c0fdfd8
3 changed files with 145 additions and 29 deletions

View File

@@ -2,6 +2,7 @@ package olm
import ( import (
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/fosrl/newt/holepunch" "github.com/fosrl/newt/holepunch"
@@ -231,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
// add the peer via the server // add the peer via the server
// this is important because newt needs to get triggered as well to add the peer once the hp is complete // this is important because newt needs to get triggered as well to add the peer once the hp is complete
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
"siteId": expectedSite.SiteId, o.peerSendMu.Lock()
}, 1*time.Second, 10) if stop, ok := o.stopPeerSends[chainId]; ok {
stop()
}
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": expectedSite.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
} else { } else {
// Existing peer - check if update is needed // Existing peer - check if update is needed

View File

@@ -2,6 +2,8 @@ package olm
import ( import (
"context" "context"
"crypto/rand"
"encoding/hex"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -65,8 +67,9 @@ type Olm struct {
stopRegister func() stopRegister func()
updateRegister func(newData any) updateRegister func(newData any)
stopPeerSend func() stopPeerSends map[string]func()
stopPeerInit func() stopPeerInits map[string]func()
peerSendMu sync.Mutex
// WaitGroup to track tunnel lifecycle // WaitGroup to track tunnel lifecycle
tunnelWg sync.WaitGroup tunnelWg sync.WaitGroup
@@ -117,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
return nil return nil
} }
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
func generateChainId() string {
b := make([]byte, 8)
_, _ = rand.Read(b)
return hex.EncodeToString(b)
}
func Init(ctx context.Context, config OlmConfig) (*Olm, error) { func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel)) logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
@@ -167,10 +177,12 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
apiServer.SetAgent(config.Agent) apiServer.SetAgent(config.Agent)
newOlm := &Olm{ newOlm := &Olm{
logFile: logFile, logFile: logFile,
olmCtx: ctx, olmCtx: ctx,
apiServer: apiServer, apiServer: apiServer,
olmConfig: config, olmConfig: config,
stopPeerSends: make(map[string]func()),
stopPeerInits: make(map[string]func()),
} }
newOlm.registerAPICallbacks() newOlm.registerAPICallbacks()
@@ -287,12 +299,17 @@ func (o *Olm) registerAPICallbacks() {
}, },
func(req api.JITConnectionRequest) error { func(req api.JITConnectionRequest) error {
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource) logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{ chainId := generateChainId()
"siteId": req.Site, o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
"siteId": req.Site,
"resourceId": req.Resource, "resourceId": req.Resource,
}, 2*time.Second, 10) "chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerInits[chainId] = stopFunc
o.peerSendMu.Unlock()
return nil return nil
}, },
) )
@@ -389,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server // Handler for peer handshake - adds exit node to holepunch rotation and notifies server
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite) o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
o.websocket.RegisterHandler("olm/sync", o.handleSync) o.websocket.RegisterHandler("olm/sync", o.handleSync)
o.websocket.OnConnect(func() error { o.websocket.OnConnect(func() error {
@@ -431,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
"userToken": userToken, "userToken": userToken,
"fingerprint": o.fingerprint, "fingerprint": o.fingerprint,
"postures": o.postures, "postures": o.postures,
}, 1*time.Second, 10) }, 2*time.Second, 10)
// Invoke onRegistered callback if configured // Invoke onRegistered callback if configured
if o.olmConfig.OnRegistered != nil { if o.olmConfig.OnRegistered != nil {
@@ -528,6 +546,22 @@ func (o *Olm) Close() {
o.stopRegister = nil o.stopRegister = nil
} }
// Stop all pending peer init and send senders before closing websocket
o.peerSendMu.Lock()
for _, stop := range o.stopPeerInits {
if stop != nil {
stop()
}
}
o.stopPeerInits = make(map[string]func())
for _, stop := range o.stopPeerSends {
if stop != nil {
stop()
}
}
o.stopPeerSends = make(map[string]func())
o.peerSendMu.Unlock()
// send a disconnect message to the cloud to show disconnected // send a disconnect message to the cloud to show disconnected
if o.websocket != nil { if o.websocket != nil {
o.websocket.SendMessage("olm/disconnecting", map[string]any{}) o.websocket.SendMessage("olm/disconnecting", map[string]any{})

View File

@@ -20,31 +20,39 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
return return
} }
if o.stopPeerSend != nil {
o.stopPeerSend()
o.stopPeerSend = nil
}
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
if err != nil { if err != nil {
logger.Error("Error marshaling data: %v", err) logger.Error("Error marshaling data: %v", err)
return return
} }
var siteConfig peers.SiteConfig var siteConfigMsg struct {
if err := json.Unmarshal(jsonData, &siteConfig); err != nil { peers.SiteConfig
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
logger.Error("Error unmarshaling add data: %v", err) logger.Error("Error unmarshaling add data: %v", err)
return return
} }
if siteConfigMsg.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
stop()
delete(o.stopPeerSends, siteConfigMsg.ChainId)
}
o.peerSendMu.Unlock()
}
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it _ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
if err := o.peerManager.AddPeer(siteConfig); err != nil { if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
logger.Error("Failed to add peer: %v", err) logger.Error("Failed to add peer: %v", err)
return return
} }
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
} }
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) { func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
@@ -230,7 +238,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
} }
var handshakeData struct { var handshakeData struct {
SiteId int `json:"siteId"` SiteId int `json:"siteId"`
ChainId string `json:"chainId"`
ExitNode struct { ExitNode struct {
PublicKey string `json:"publicKey"` PublicKey string `json:"publicKey"`
Endpoint string `json:"endpoint"` Endpoint string `json:"endpoint"`
@@ -242,6 +251,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
logger.Error("Error unmarshaling handshake data: %v", err) logger.Error("Error unmarshaling handshake data: %v", err)
return return
} }
// Stop the peer init sender for this chain, if any
if handshakeData.ChainId != "" {
o.peerSendMu.Lock()
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
stop()
delete(o.stopPeerInits, handshakeData.ChainId)
}
o.peerSendMu.Unlock()
}
// Get existing peer from PeerManager // Get existing peer from PeerManager
_, exists := o.peerManager.GetPeer(handshakeData.SiteId) _, exists := o.peerManager.GetPeer(handshakeData.SiteId)
@@ -273,10 +292,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
// Send handshake acknowledgment back to server with retry // Send handshake acknowledgment back to server with retry, keyed by chainId
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{ chainId := handshakeData.ChainId
"siteId": handshakeData.SiteId, if chainId == "" {
}, 1*time.Second, 10) chainId = generateChainId()
}
o.peerSendMu.Lock()
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
"siteId": handshakeData.SiteId,
"chainId": chainId,
}, 2*time.Second, 10)
o.stopPeerSends[chainId] = stopFunc
o.peerSendMu.Unlock()
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint) logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
} }
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
logger.Debug("Received cancel-chain message: %v", msg.Data)
jsonData, err := json.Marshal(msg.Data)
if err != nil {
logger.Error("Error marshaling cancel-chain data: %v", err)
return
}
var cancelData struct {
ChainId string `json:"chainId"`
}
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
logger.Error("Error unmarshaling cancel-chain data: %v", err)
return
}
if cancelData.ChainId == "" {
logger.Warn("Received cancel-chain message with no chainId")
return
}
o.peerSendMu.Lock()
defer o.peerSendMu.Unlock()
found := false
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerInits, cancelData.ChainId)
found = true
}
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
stop()
delete(o.stopPeerSends, cancelData.ChainId)
found = true
}
if found {
logger.Info("Cancelled chain %s", cancelData.ChainId)
} else {
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
}
}