mirror of
https://github.com/fosrl/olm.git
synced 2026-02-07 21:46:40 +00:00
139
olm/data.go
139
olm/data.go
@@ -135,67 +135,6 @@ func (o *Olm) handleWgPeerUpdateData(msg websocket.WSMessage) {
|
|||||||
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
|
logger.Info("Successfully updated remote subnets and aliases for peer %d", updateSubnetsData.SiteId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
|
||||||
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
|
||||||
|
|
||||||
jsonData, err := json.Marshal(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Error marshaling handshake data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var handshakeData struct {
|
|
||||||
SiteId int `json:"siteId"`
|
|
||||||
ExitNode struct {
|
|
||||||
PublicKey string `json:"publicKey"`
|
|
||||||
Endpoint string `json:"endpoint"`
|
|
||||||
RelayPort uint16 `json:"relayPort"`
|
|
||||||
} `json:"exitNode"`
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
|
||||||
logger.Error("Error unmarshaling handshake data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get existing peer from PeerManager
|
|
||||||
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
|
||||||
if exists {
|
|
||||||
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
relayPort := handshakeData.ExitNode.RelayPort
|
|
||||||
if relayPort == 0 {
|
|
||||||
relayPort = 21820 // default relay port
|
|
||||||
}
|
|
||||||
|
|
||||||
siteId := handshakeData.SiteId
|
|
||||||
exitNode := holepunch.ExitNode{
|
|
||||||
Endpoint: handshakeData.ExitNode.Endpoint,
|
|
||||||
RelayPort: relayPort,
|
|
||||||
PublicKey: handshakeData.ExitNode.PublicKey,
|
|
||||||
SiteIds: []int{siteId},
|
|
||||||
}
|
|
||||||
|
|
||||||
added := o.holePunchManager.AddExitNode(exitNode)
|
|
||||||
if added {
|
|
||||||
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
|
||||||
} else {
|
|
||||||
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
|
||||||
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
|
||||||
|
|
||||||
// Send handshake acknowledgment back to server with retry
|
|
||||||
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
|
||||||
"siteId": handshakeData.SiteId,
|
|
||||||
}, 1*time.Second, 10)
|
|
||||||
|
|
||||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handler for syncing peer configuration - reconciles expected state with actual state
|
// Handler for syncing peer configuration - reconciles expected state with actual state
|
||||||
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
func (o *Olm) handleSync(msg websocket.WSMessage) {
|
||||||
logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data)
|
logger.Debug("++++++++++++++++++++++++++++Received sync message: %v", msg.Data)
|
||||||
@@ -222,6 +161,9 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sync exit nodes for hole punching
|
||||||
|
o.syncExitNodes(syncData.ExitNodes)
|
||||||
|
|
||||||
// Build a map of expected peers from the incoming data
|
// Build a map of expected peers from the incoming data
|
||||||
expectedPeers := make(map[int]peers.SiteConfig)
|
expectedPeers := make(map[int]peers.SiteConfig)
|
||||||
for _, site := range syncData.Sites {
|
for _, site := range syncData.Sites {
|
||||||
@@ -259,15 +201,21 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
// New peer - add it using the add flow (with holepunch)
|
// New peer - add it using the add flow (with holepunch)
|
||||||
logger.Info("Sync: Adding new peer for site %d", siteId)
|
logger.Info("Sync: Adding new peer for site %d", siteId)
|
||||||
|
|
||||||
// Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
|
||||||
o.holePunchManager.TriggerHolePunch()
|
o.holePunchManager.TriggerHolePunch()
|
||||||
|
|
||||||
// TODO: do we need to send the message to the cloud to add the peer that way?
|
// // TODO: do we need to send the message to the cloud to add the peer that way?
|
||||||
if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
// if err := o.peerManager.AddPeer(expectedSite); err != nil {
|
||||||
logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
|
// logger.Error("Sync: Failed to add peer %d: %v", siteId, err)
|
||||||
} else {
|
// } else {
|
||||||
logger.Info("Sync: Successfully added peer for site %d", siteId)
|
// logger.Info("Sync: Successfully added peer for site %d", siteId)
|
||||||
}
|
// }
|
||||||
|
|
||||||
|
// add the peer via the server
|
||||||
|
// this is important because newt needs to get triggered as well to add the peer once the hp is complete
|
||||||
|
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": expectedSite.SiteId,
|
||||||
|
}, 1*time.Second, 10)
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
// Existing peer - check if update is needed
|
// Existing peer - check if update is needed
|
||||||
currentSite := currentPeerMap[siteId]
|
currentSite := currentPeerMap[siteId]
|
||||||
@@ -342,3 +290,58 @@ func (o *Olm) handleSync(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
|
logger.Info("Sync completed: processed %d expected peers, had %d current peers", len(expectedPeers), len(currentPeers))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// syncExitNodes reconciles the expected exit nodes with the current ones in the hole punch manager
|
||||||
|
func (o *Olm) syncExitNodes(expectedExitNodes []SyncExitNode) {
|
||||||
|
if o.holePunchManager == nil {
|
||||||
|
logger.Warn("Hole punch manager not initialized, skipping exit node sync")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a map of expected exit nodes by endpoint
|
||||||
|
expectedExitNodeMap := make(map[string]SyncExitNode)
|
||||||
|
for _, exitNode := range expectedExitNodes {
|
||||||
|
expectedExitNodeMap[exitNode.Endpoint] = exitNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current exit nodes from hole punch manager
|
||||||
|
currentExitNodes := o.holePunchManager.GetExitNodes()
|
||||||
|
currentExitNodeMap := make(map[string]holepunch.ExitNode)
|
||||||
|
for _, exitNode := range currentExitNodes {
|
||||||
|
currentExitNodeMap[exitNode.Endpoint] = exitNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find exit nodes to remove (in current but not in expected)
|
||||||
|
for endpoint := range currentExitNodeMap {
|
||||||
|
if _, exists := expectedExitNodeMap[endpoint]; !exists {
|
||||||
|
logger.Info("Sync: Removing exit node %s (no longer in expected config)", endpoint)
|
||||||
|
o.holePunchManager.RemoveExitNode(endpoint)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find exit nodes to add (in expected but not in current)
|
||||||
|
for endpoint, expectedExitNode := range expectedExitNodeMap {
|
||||||
|
if _, exists := currentExitNodeMap[endpoint]; !exists {
|
||||||
|
logger.Info("Sync: Adding new exit node %s", endpoint)
|
||||||
|
|
||||||
|
relayPort := expectedExitNode.RelayPort
|
||||||
|
if relayPort == 0 {
|
||||||
|
relayPort = 21820 // default relay port
|
||||||
|
}
|
||||||
|
|
||||||
|
hpExitNode := holepunch.ExitNode{
|
||||||
|
Endpoint: expectedExitNode.Endpoint,
|
||||||
|
RelayPort: relayPort,
|
||||||
|
PublicKey: expectedExitNode.PublicKey,
|
||||||
|
SiteIds: expectedExitNode.SiteIds,
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.holePunchManager.AddExitNode(hpExitNode) {
|
||||||
|
logger.Info("Sync: Successfully added exit node %s", endpoint)
|
||||||
|
}
|
||||||
|
o.holePunchManager.TriggerHolePunch()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Sync exit nodes completed: processed %d expected exit nodes, had %d current exit nodes", len(expectedExitNodeMap), len(currentExitNodeMap))
|
||||||
|
}
|
||||||
|
|||||||
63
olm/peer.go
63
olm/peer.go
@@ -2,7 +2,9 @@ package olm
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/newt/holepunch"
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/util"
|
"github.com/fosrl/newt/util"
|
||||||
"github.com/fosrl/olm/peers"
|
"github.com/fosrl/olm/peers"
|
||||||
@@ -193,3 +195,64 @@ func (o *Olm) handleWgPeerUnrelay(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
|
o.peerManager.UnRelayPeer(relayData.SiteId, primaryRelay)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (o *Olm) handleWgPeerHolepunchAddSite(msg websocket.WSMessage) {
|
||||||
|
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(msg.Data)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("Error marshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var handshakeData struct {
|
||||||
|
SiteId int `json:"siteId"`
|
||||||
|
ExitNode struct {
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
RelayPort uint16 `json:"relayPort"`
|
||||||
|
} `json:"exitNode"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||||
|
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get existing peer from PeerManager
|
||||||
|
_, exists := o.peerManager.GetPeer(handshakeData.SiteId)
|
||||||
|
if exists {
|
||||||
|
logger.Warn("Peer with site ID %d already added", handshakeData.SiteId)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
relayPort := handshakeData.ExitNode.RelayPort
|
||||||
|
if relayPort == 0 {
|
||||||
|
relayPort = 21820 // default relay port
|
||||||
|
}
|
||||||
|
|
||||||
|
siteId := handshakeData.SiteId
|
||||||
|
exitNode := holepunch.ExitNode{
|
||||||
|
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||||
|
RelayPort: relayPort,
|
||||||
|
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||||
|
SiteIds: []int{siteId},
|
||||||
|
}
|
||||||
|
|
||||||
|
added := o.holePunchManager.AddExitNode(exitNode)
|
||||||
|
if added {
|
||||||
|
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||||
|
} else {
|
||||||
|
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
o.holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt
|
||||||
|
o.holePunchManager.ResetServerHolepunchInterval() // start sending immediately again so we fill in the endpoint on the cloud
|
||||||
|
|
||||||
|
// Send handshake acknowledgment back to server with retry
|
||||||
|
o.stopPeerSend, _ = o.websocket.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||||
|
"siteId": handshakeData.SiteId,
|
||||||
|
}, 1*time.Second, 10)
|
||||||
|
|
||||||
|
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||||
|
}
|
||||||
|
|||||||
10
olm/types.go
10
olm/types.go
@@ -13,7 +13,15 @@ type WgData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SyncData struct {
|
type SyncData struct {
|
||||||
Sites []peers.SiteConfig `json:"sites"`
|
Sites []peers.SiteConfig `json:"sites"`
|
||||||
|
ExitNodes []SyncExitNode `json:"exitNodes"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type SyncExitNode struct {
|
||||||
|
Endpoint string `json:"endpoint"`
|
||||||
|
RelayPort uint16 `json:"relayPort"`
|
||||||
|
PublicKey string `json:"publicKey"`
|
||||||
|
SiteIds []int `json:"siteIds"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OlmConfig struct {
|
type OlmConfig struct {
|
||||||
|
|||||||
@@ -96,6 +96,9 @@ type Client struct {
|
|||||||
exitNodes []ExitNode // Cached exit nodes from token response
|
exitNodes []ExitNode // Cached exit nodes from token response
|
||||||
tokenMux sync.RWMutex // Protects token and exitNodes
|
tokenMux sync.RWMutex // Protects token and exitNodes
|
||||||
forceNewToken bool // Flag to force fetching a new token on next connection
|
forceNewToken bool // Flag to force fetching a new token on next connection
|
||||||
|
processingMessage bool // Flag to track if a message is currently being processed
|
||||||
|
processingMux sync.RWMutex // Protects processingMessage
|
||||||
|
processingWg sync.WaitGroup // WaitGroup to wait for message processing to complete
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientOption func(*Client)
|
type ClientOption func(*Client)
|
||||||
@@ -222,6 +225,9 @@ func (c *Client) Disconnect() error {
|
|||||||
c.isDisconnected = true
|
c.isDisconnected = true
|
||||||
c.setConnected(false)
|
c.setConnected(false)
|
||||||
|
|
||||||
|
// Wait for any message currently being processed to complete
|
||||||
|
c.processingWg.Wait()
|
||||||
|
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
c.writeMux.Lock()
|
c.writeMux.Lock()
|
||||||
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
||||||
@@ -651,6 +657,14 @@ func (c *Client) pingMonitor() {
|
|||||||
if c.isDisconnected || c.conn == nil {
|
if c.isDisconnected || c.conn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Skip ping if a message is currently being processed
|
||||||
|
c.processingMux.RLock()
|
||||||
|
isProcessing := c.processingMessage
|
||||||
|
c.processingMux.RUnlock()
|
||||||
|
if isProcessing {
|
||||||
|
logger.Debug("websocket: Skipping ping, message is being processed")
|
||||||
|
continue
|
||||||
|
}
|
||||||
// Send application-level ping with config version
|
// Send application-level ping with config version
|
||||||
c.configVersionMux.RLock()
|
c.configVersionMux.RLock()
|
||||||
configVersion := c.configVersion
|
configVersion := c.configVersion
|
||||||
@@ -753,7 +767,19 @@ func (c *Client) readPumpWithDisconnectDetection() {
|
|||||||
|
|
||||||
c.handlersMux.RLock()
|
c.handlersMux.RLock()
|
||||||
if handler, ok := c.handlers[msg.Type]; ok {
|
if handler, ok := c.handlers[msg.Type]; ok {
|
||||||
|
// Mark that we're processing a message
|
||||||
|
c.processingMux.Lock()
|
||||||
|
c.processingMessage = true
|
||||||
|
c.processingMux.Unlock()
|
||||||
|
c.processingWg.Add(1)
|
||||||
|
|
||||||
handler(msg)
|
handler(msg)
|
||||||
|
|
||||||
|
// Mark that we're done processing
|
||||||
|
c.processingWg.Done()
|
||||||
|
c.processingMux.Lock()
|
||||||
|
c.processingMessage = false
|
||||||
|
c.processingMux.Unlock()
|
||||||
}
|
}
|
||||||
c.handlersMux.RUnlock()
|
c.handlersMux.RUnlock()
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user