mirror of
https://github.com/fosrl/olm.git
synced 2026-03-07 03:06:44 +00:00
Working jit with chain ids
This commit is contained in:
15
olm/data.go
15
olm/data.go
@@ -2,6 +2,7 @@ package olm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fosrl/newt/holepunch"
|
"github.com/fosrl/newt/holepunch"
|
||||||
@@ -231,9 +232,17 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// add the peer via the server
|
// add the peer via the server
|
||||||
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
chainId := fmt.Sprintf("sync-%d", expectedSite.SiteId)
|
||||||
"siteId": expectedSite.SiteId,
|
o.peerSendMu.Lock()
|
||||||
}, 1*time.Second, 10)
|
if stop, ok := o.stopPeerSends[chainId]; ok {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": expectedSite.SiteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerSends[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Existing peer - check if update is needed
|
// Existing peer - check if update is needed
|
||||||
|
|||||||
58
olm/olm.go
58
olm/olm.go
@@ -2,6 +2,8 @@ package olm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -65,8 +67,9 @@ type Olm struct {
|
|||||||
stopRegister func()
|
stopRegister func()
|
||||||
updateRegister func(newData any)
|
updateRegister func(newData any)
|
||||||
|
|
||||||
stopPeerSend func()
|
stopPeerSends map[string]func()
|
||||||
stopPeerInit func()
|
stopPeerInits map[string]func()
|
||||||
|
peerSendMu sync.Mutex
|
||||||
|
|
||||||
// WaitGroup to track tunnel lifecycle
|
// WaitGroup to track tunnel lifecycle
|
||||||
tunnelWg sync.WaitGroup
|
tunnelWg sync.WaitGroup
|
||||||
@@ -117,6 +120,13 @@ func (o *Olm) initTunnelInfo(clientID string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateChainId generates a random chain ID for tracking peer sender lifecycles.
|
||||||
|
func generateChainId() string {
|
||||||
|
b := make([]byte, 8)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return hex.EncodeToString(b)
|
||||||
|
}
|
||||||
|
|
||||||
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
||||||
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
logger.GetLogger().SetLevel(util.ParseLogLevel(config.LogLevel))
|
||||||
|
|
||||||
@@ -167,10 +177,12 @@ func Init(ctx context.Context, config OlmConfig) (*Olm, error) {
|
|||||||
apiServer.SetAgent(config.Agent)
|
apiServer.SetAgent(config.Agent)
|
||||||
|
|
||||||
newOlm := &Olm{
|
newOlm := &Olm{
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
olmCtx: ctx,
|
olmCtx: ctx,
|
||||||
apiServer: apiServer,
|
apiServer: apiServer,
|
||||||
olmConfig: config,
|
olmConfig: config,
|
||||||
|
stopPeerSends: make(map[string]func()),
|
||||||
|
stopPeerInits: make(map[string]func()),
|
||||||
}
|
}
|
||||||
|
|
||||||
newOlm.registerAPICallbacks()
|
newOlm.registerAPICallbacks()
|
||||||
@@ -287,12 +299,17 @@ func (o *Olm) registerAPICallbacks() {
|
|||||||
},
|
},
|
||||||
func(req api.JITConnectionRequest) error {
|
func(req api.JITConnectionRequest) error {
|
||||||
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
|
logger.Info("Processing JIT connect request via API: site=%s resource=%s", req.Site, req.Resource)
|
||||||
|
|
||||||
o.stopPeerInit, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
chainId := generateChainId()
|
||||||
"siteId": req.Site,
|
o.peerSendMu.Lock()
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/init", map[string]interface{}{
|
||||||
|
"siteId": req.Site,
|
||||||
"resourceId": req.Resource,
|
"resourceId": req.Resource,
|
||||||
}, 2*time.Second, 10)
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerInits[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -389,6 +406,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
|
|
||||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||||
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
o.websocket.RegisterHandler("olm/wg/peer/holepunch/site/add", o.handleWgPeerHolepunchAddSite)
|
||||||
|
o.websocket.RegisterHandler("olm/wg/peer/chain/cancel", o.handleCancelChain)
|
||||||
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
o.websocket.RegisterHandler("olm/sync", o.handleSync)
|
||||||
|
|
||||||
o.websocket.OnConnect(func() error {
|
o.websocket.OnConnect(func() error {
|
||||||
@@ -431,7 +449,7 @@ func (o *Olm) StartTunnel(config TunnelConfig) {
|
|||||||
"userToken": userToken,
|
"userToken": userToken,
|
||||||
"fingerprint": o.fingerprint,
|
"fingerprint": o.fingerprint,
|
||||||
"postures": o.postures,
|
"postures": o.postures,
|
||||||
}, 1*time.Second, 10)
|
}, 2*time.Second, 10)
|
||||||
|
|
||||||
// Invoke onRegistered callback if configured
|
// Invoke onRegistered callback if configured
|
||||||
if o.olmConfig.OnRegistered != nil {
|
if o.olmConfig.OnRegistered != nil {
|
||||||
@@ -528,6 +546,22 @@ func (o *Olm) Close() {
|
|||||||
o.stopRegister = nil
|
o.stopRegister = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop all pending peer init and send senders before closing websocket
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
for _, stop := range o.stopPeerInits {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o.stopPeerInits = make(map[string]func())
|
||||||
|
for _, stop := range o.stopPeerSends {
|
||||||
|
if stop != nil {
|
||||||
|
stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o.stopPeerSends = make(map[string]func())
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
// send a disconnect message to the cloud to show disconnected
|
// send a disconnect message to the cloud to show disconnected
|
||||||
if o.websocket != nil {
|
if o.websocket != nil {
|
||||||
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
o.websocket.SendMessage("olm/disconnecting", map[string]any{})
|
||||||
|
|||||||
101
olm/peer.go
101
olm/peer.go
@@ -20,31 +20,39 @@ func (o *Olm) handleWgPeerAdd(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if o.stopPeerSend != nil {
|
|
||||||
o.stopPeerSend()
|
|
||||||
o.stopPeerSend = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Error marshaling data: %v", err)
|
logger.Error("Error marshaling data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var siteConfig peers.SiteConfig
|
var siteConfigMsg struct {
|
||||||
if err := json.Unmarshal(jsonData, &siteConfig); err != nil {
|
peers.SiteConfig
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &siteConfigMsg); err != nil {
|
||||||
logger.Error("Error unmarshaling add data: %v", err)
|
logger.Error("Error unmarshaling add data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if siteConfigMsg.ChainId != "" {
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
if stop, ok := o.stopPeerSends[siteConfigMsg.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerSends, siteConfigMsg.ChainId)
|
||||||
|
}
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
_ = o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||||
|
|
||||||
if err := o.peerManager.AddPeer(siteConfig); err != nil {
|
if err := o.peerManager.AddPeer(siteConfigMsg.SiteConfig); err != nil {
|
||||||
logger.Error("Failed to add peer: %v", err)
|
logger.Error("Failed to add peer: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Info("Successfully added peer for site %d", siteConfig.SiteId)
|
|
||||||
|
logger.Info("Successfully added peer for site %d", siteConfigMsg.SiteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
func (o *Olm) handleWgPeerRemove(msg websocket.WSMessage) {
|
||||||
@@ -230,7 +238,8 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var handshakeData struct {
|
var handshakeData struct {
|
||||||
SiteId int `json:"siteId"`
|
SiteId int `json:"siteId"`
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
ExitNode struct {
|
ExitNode struct {
|
||||||
PublicKey string `json:"publicKey"`
|
PublicKey string `json:"publicKey"`
|
||||||
Endpoint string `json:"endpoint"`
|
Endpoint string `json:"endpoint"`
|
||||||
@@ -242,6 +251,16 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
logger.Error("Error unmarshaling handshake data: %v", err)
|
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Stop the peer init sender for this chain, if any
|
||||||
|
if handshakeData.ChainId != "" {
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
if stop, ok := o.stopPeerInits[handshakeData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerInits, handshakeData.ChainId)
|
||||||
|
}
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Get existing peer from PeerManager
|
// Get existing peer from PeerManager
|
||||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||||
@@ -273,10 +292,64 @@ func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|||||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
// Send handshake acknowledgment back to server with retry
|
// Send handshake acknowledgment back to server with retry, keyed by chainId
|
||||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
chainId := handshakeData.ChainId
|
||||||
"siteId": handshakeData.SiteId,
|
if chainId == "" {
|
||||||
}, 1*time.Second, 10)
|
chainId = generateChainId()
|
||||||
|
}
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
stopFunc, _ := o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
"chainId": chainId,
|
||||||
|
}, 2*time.Second, 10)
|
||||||
|
o.stopPeerSends[chainId] = stopFunc
|
||||||
|
o.peerSendMu.Unlock()
|
||||||
|
|
||||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleCancelChain(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received cancel-chain message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling cancel-chain data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancelData struct {
|
||||||
|
ChainId string `json:"chainId"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(jsonData, &cancelData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling cancel-chain data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if cancelData.ChainId == "" {
|
||||||
|
logger.Warn("Received cancel-chain message with no chainId")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
o.peerSendMu.Lock()
|
||||||
|
defer o.peerSendMu.Unlock()
|
||||||
|
|
||||||
|
found := false
|
||||||
|
|
||||||
|
if stop, ok := o.stopPeerInits[cancelData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerInits, cancelData.ChainId)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if stop, ok := o.stopPeerSends[cancelData.ChainId]; ok {
|
||||||
|
stop()
|
||||||
|
delete(o.stopPeerSends, cancelData.ChainId)
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found {
|
||||||
|
logger.Info("Cancelled chain %s", cancelData.ChainId)
|
||||||
|
} else {
|
||||||
|
logger.Warn("Cancel-chain: no active sender found for chain %s", cancelData.ChainId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user