mirror of
https://github.com/fosrl/olm.git
synced 2026-02-20 20:06:43 +00:00
Handle holepunches better
This commit is contained in:
190
olm/olm.go
190
olm/olm.go
@@ -33,7 +33,6 @@ var (
|
||||
connected bool
|
||||
dev *device.Device
|
||||
wgData WgData
|
||||
holePunchData HolePunchData
|
||||
uapiListener net.Listener
|
||||
tdev tun.Device
|
||||
middleDev *olmDevice.MiddleDevice
|
||||
@@ -48,13 +47,22 @@ var (
|
||||
globalConfig GlobalConfig
|
||||
globalCtx context.Context
|
||||
stopRegister func()
|
||||
stopPeerSend func()
|
||||
updateRegister func(newData interface{})
|
||||
stopPing chan struct{}
|
||||
peerManager *peers.PeerManager
|
||||
)
|
||||
|
||||
// initSharedBindAndHolepunch creates the shared UDP socket and holepunch manager.
|
||||
// initTunnelInfo creates the shared UDP socket and holepunch manager.
|
||||
// This is used during initial tunnel setup and when switching organizations.
|
||||
func initSharedBindAndHolepunch(clientID string) error {
|
||||
func initTunnelInfo(clientID string) error {
|
||||
var err error
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Error("Failed to generate private key: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
sourcePort, err := util.FindAvailableUDPPort(49152, 65535)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find available UDP port: %w", err)
|
||||
@@ -82,7 +90,7 @@ func initSharedBindAndHolepunch(clientID string) error {
|
||||
logger.Info("Created shared UDP socket on port %d (refcount: %d)", sourcePort, sharedBind.GetRefCount())
|
||||
|
||||
// Create the holepunch manager
|
||||
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm")
|
||||
holePunchManager = holepunch.NewManager(sharedBind, clientID, "olm", privateKey.PublicKey().String())
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -249,82 +257,12 @@ func StartTunnel(config TunnelConfig) {
|
||||
// Store the client reference globally
|
||||
olmClient = olm
|
||||
|
||||
privateKey, err = wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
logger.Error("Failed to generate private key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create shared UDP socket and holepunch manager
|
||||
if err := initSharedBindAndHolepunch(id); err != nil {
|
||||
if err := initTunnelInfo(id); err != nil {
|
||||
logger.Error("%v", err)
|
||||
return
|
||||
}
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &holePunchData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Convert HolePunchData.ExitNodes to holepunch.ExitNode slice
|
||||
exitNodes := make([]holepunch.ExitNode, len(holePunchData.ExitNodes))
|
||||
for i, node := range holePunchData.ExitNodes {
|
||||
exitNodes[i] = holepunch.ExitNode{
|
||||
Endpoint: node.Endpoint,
|
||||
PublicKey: node.PublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
// Start hole punching using the manager
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||
if err := holePunchManager.StartMultipleExitNodes(exitNodes); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) {
|
||||
// THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
type LegacyHolePunchData struct {
|
||||
ServerPubKey string `json:"serverPubKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
var legacyHolePunchData LegacyHolePunchData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Info("Error marshaling data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil {
|
||||
logger.Info("Error unmarshaling target data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Stop any existing hole punch operations
|
||||
if holePunchManager != nil {
|
||||
holePunchManager.Stop()
|
||||
}
|
||||
|
||||
// Start hole punching for the exit node
|
||||
logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey)
|
||||
if err := holePunchManager.StartSingleEndpoint(legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received message: %v", msg.Data)
|
||||
|
||||
@@ -338,9 +276,9 @@ func StartTunnel(config TunnelConfig) {
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
// wait 10 milliseconds to ensure the previous connection is closed
|
||||
logger.Debug("Waiting 500 milliseconds to ensure previous connection is closed")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
if updateRegister != nil {
|
||||
updateRegister = nil
|
||||
}
|
||||
|
||||
// if there is an existing tunnel then close it
|
||||
if dev != nil {
|
||||
@@ -572,6 +510,11 @@ func StartTunnel(config TunnelConfig) {
|
||||
olm.RegisterHandler("olm/wg/peer/add", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received add-peer message: %v", msg.Data)
|
||||
|
||||
if stopPeerSend != nil {
|
||||
stopPeerSend()
|
||||
stopPeerSend = nil
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling data: %v", err)
|
||||
@@ -584,6 +527,8 @@ func StartTunnel(config TunnelConfig) {
|
||||
return
|
||||
}
|
||||
|
||||
holePunchManager.TriggerHolePunch() // Trigger immediate hole punch attempt so that if the peer decides to relay we have already punched close to when we need it
|
||||
|
||||
if err := peerManager.AddPeer(siteConfig, endpoint); err != nil {
|
||||
logger.Error("Failed to add peer: %v", err)
|
||||
return
|
||||
@@ -753,6 +698,59 @@ func StartTunnel(config TunnelConfig) {
|
||||
peerMonitor.HandleFailover(relayData.SiteId, primaryRelay)
|
||||
})
|
||||
|
||||
// Handler for peer handshake - adds exit node to holepunch rotation and notifies server
|
||||
olm.RegisterHandler("olm/wg/peer/holepunch/site/add", func(msg websocket.WSMessage) {
|
||||
logger.Debug("Received peer-handshake message: %v", msg.Data)
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
if err != nil {
|
||||
logger.Error("Error marshaling handshake data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var handshakeData struct {
|
||||
SiteId int `json:"siteId"`
|
||||
ExitNode struct {
|
||||
PublicKey string `json:"publicKey"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
} `json:"exitNode"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(jsonData, &handshakeData); err != nil {
|
||||
logger.Error("Error unmarshaling handshake data: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Add exit node to holepunch rotation if we have a holepunch manager
|
||||
if holePunchManager != nil {
|
||||
exitNode := holepunch.ExitNode{
|
||||
Endpoint: handshakeData.ExitNode.Endpoint,
|
||||
PublicKey: handshakeData.ExitNode.PublicKey,
|
||||
}
|
||||
|
||||
added := holePunchManager.AddExitNode(exitNode)
|
||||
if added {
|
||||
logger.Info("Added exit node %s to holepunch rotation for handshake", exitNode.Endpoint)
|
||||
} else {
|
||||
logger.Debug("Exit node %s already in holepunch rotation", exitNode.Endpoint)
|
||||
}
|
||||
|
||||
// Start holepunching if not already running
|
||||
if !holePunchManager.IsRunning() {
|
||||
if err := holePunchManager.Start(); err != nil {
|
||||
logger.Error("Failed to start holepunch manager: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Send handshake acknowledgment back to server with retry
|
||||
stopPeerSend, _ = olm.SendMessageInterval("olm/wg/server/peer/add", map[string]interface{}{
|
||||
"siteId": handshakeData.SiteId,
|
||||
}, 1*time.Second)
|
||||
|
||||
logger.Info("Initiated handshake for site %d with exit node %s", handshakeData.SiteId, handshakeData.ExitNode.Endpoint)
|
||||
})
|
||||
|
||||
olm.RegisterHandler("olm/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received terminate message")
|
||||
apiServer.SetTerminated(true)
|
||||
@@ -779,15 +777,17 @@ func StartTunnel(config TunnelConfig) {
|
||||
|
||||
publicKey := privateKey.PublicKey()
|
||||
|
||||
// delay for 500ms to allow for time for the hp to get processed
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
if stopRegister == nil {
|
||||
logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !config.Holepunch)
|
||||
stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
stopRegister, updateRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": !config.Holepunch,
|
||||
"olmVersion": globalConfig.Version,
|
||||
"orgId": config.OrgID,
|
||||
"userToken": userToken,
|
||||
// "doNotCreateNewClient": config.DoNotCreateNewClient,
|
||||
}, 1*time.Second)
|
||||
|
||||
// Invoke onRegistered callback if configured
|
||||
@@ -801,9 +801,28 @@ func StartTunnel(config TunnelConfig) {
|
||||
return nil
|
||||
})
|
||||
|
||||
olm.OnTokenUpdate(func(token string) {
|
||||
olm.OnTokenUpdate(func(token string, exitNodes []websocket.ExitNode) {
|
||||
if holePunchManager != nil {
|
||||
holePunchManager.SetToken(token)
|
||||
|
||||
logger.Debug("Got exit nodes for hole punching: %v", exitNodes)
|
||||
|
||||
// Convert websocket.ExitNode to holepunch.ExitNode
|
||||
hpExitNodes := make([]holepunch.ExitNode, len(exitNodes))
|
||||
for i, node := range exitNodes {
|
||||
hpExitNodes[i] = holepunch.ExitNode{
|
||||
Endpoint: node.Endpoint,
|
||||
PublicKey: node.PublicKey,
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug("Updated hole punch exit nodes: %v", hpExitNodes)
|
||||
|
||||
// Start hole punching using the manager
|
||||
logger.Info("Starting hole punch for %d exit nodes", len(exitNodes))
|
||||
if err := holePunchManager.StartMultipleExitNodes(hpExitNodes); err != nil {
|
||||
logger.Warn("Failed to start hole punch: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -814,6 +833,7 @@ func StartTunnel(config TunnelConfig) {
|
||||
apiServer.SetRegistered(false)
|
||||
apiServer.ClearPeerStatuses()
|
||||
network.ClearNetworkSettings()
|
||||
|
||||
Close()
|
||||
|
||||
if globalConfig.OnAuthError != nil {
|
||||
@@ -864,6 +884,10 @@ func Close() {
|
||||
stopRegister = nil
|
||||
}
|
||||
|
||||
if updateRegister != nil {
|
||||
updateRegister = nil
|
||||
}
|
||||
|
||||
if peerMonitor != nil {
|
||||
peerMonitor.Close() // Close() also calls Stop() internally
|
||||
peerMonitor = nil
|
||||
@@ -992,7 +1016,7 @@ func SwitchOrg(orgID string) error {
|
||||
Close()
|
||||
|
||||
// Recreate sharedBind and holepunch manager - needed because Close() releases them
|
||||
if err := initSharedBindAndHolepunch(olmClient.GetConfig().ID); err != nil {
|
||||
if err := initTunnelInfo(olmClient.GetConfig().ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1002,7 +1026,7 @@ func SwitchOrg(orgID string) error {
|
||||
// Trigger re-registration with new orgId
|
||||
logger.Info("Re-registering with new orgId: %s", orgID)
|
||||
publicKey := privateKey.PublicKey()
|
||||
stopRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
stopRegister, updateRegister = olmClient.SendMessageInterval("olm/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"relay": true, // Default to relay mode for org switch
|
||||
"olmVersion": globalConfig.Version,
|
||||
|
||||
19
olm/types.go
19
olm/types.go
@@ -12,25 +12,6 @@ type WgData struct {
|
||||
UtilitySubnet string `json:"utilitySubnet"` // this is for things like the DNS server, and alias addresses
|
||||
}
|
||||
|
||||
type HolePunchMessage struct {
|
||||
NewtID string `json:"newtId"`
|
||||
}
|
||||
|
||||
type ExitNode struct {
|
||||
Endpoint string `json:"endpoint"`
|
||||
PublicKey string `json:"publicKey"`
|
||||
}
|
||||
|
||||
type HolePunchData struct {
|
||||
ExitNodes []ExitNode `json:"exitNodes"`
|
||||
}
|
||||
|
||||
type EncryptedHolePunchMessage struct {
|
||||
EphemeralPublicKey string `json:"ephemeralPublicKey"`
|
||||
Nonce []byte `json:"nonce"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}
|
||||
|
||||
type GlobalConfig struct {
|
||||
// Logging
|
||||
LogLevel string
|
||||
|
||||
Reference in New Issue
Block a user