From 219df229192eb456dd5bf9ac778c3200d81f6344 Mon Sep 17 00:00:00 2001 From: Owen Date: Mon, 4 Aug 2025 20:43:38 -0700 Subject: [PATCH] Hp to all exit nodes Former-commit-id: b6fb17d8494beb93c5e70f207f759adef9001d2f --- common.go | 46 +++++++++++++++++++--------------------------- main.go | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/common.go b/common.go index db8c155..6bf2bc4 100644 --- a/common.go +++ b/common.go @@ -52,9 +52,13 @@ type HolePunchMessage struct { NewtID string `json:"newtId"` } +type ExitNode struct { + Endpoint string `json:"endpoint"` + PublicKey string `json:"publicKey"` +} + type HolePunchData struct { - ServerPubKey string `json:"serverPubKey"` - Endpoint string `json:"endpoint"` + ExitNodes []ExitNode `json:"exitNodes"` } type EncryptedHolePunchMessage struct { @@ -64,13 +68,11 @@ type EncryptedHolePunchMessage struct { } var ( - peerMonitor *peermonitor.PeerMonitor - stopHolepunch chan struct{} - stopRegister func() - stopPing chan struct{} - olmToken string - gerbilServerPubKey string - holePunchRunning bool + peerMonitor *peermonitor.PeerMonitor + stopHolepunch chan struct{} + stopRegister func() + stopPing chan struct{} + olmToken string ) const ( @@ -226,8 +228,8 @@ func resolveDomain(domain string) (string, error) { return ipAddr, nil } -func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string) error { - if gerbilServerPubKey == "" || olmToken == "" { +func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID string, serverPubKey string) error { + if serverPubKey == "" || olmToken == "" { return nil } @@ -246,7 +248,7 @@ func sendUDPHolePunchWithConn(conn *net.UDPConn, remoteAddr *net.UDPAddr, olmID } // Encrypt the payload using the server's WireGuard public key - encryptedPayload, err := encryptPayload(payloadBytes, gerbilServerPubKey) + encryptedPayload, err := encryptPayload(payloadBytes, serverPubKey) if err != nil { return fmt.Errorf("failed to encrypt payload: %v", err) } @@ -319,19 +321,9 @@ func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) return encryptedMsg, nil } -func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { - // Check if hole punching is already running - if holePunchRunning { - logger.Debug("UDP hole punch already running, skipping new request") - return - } - - // Set the flag to indicate hole punching is running - holePunchRunning = true - defer func() { - holePunchRunning = false - logger.Info("UDP hole punch goroutine ended") - }() +func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) { + logger.Info("Starting UDP hole punch to %s", endpoint) + defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) host, err := resolveDomain(endpoint) if err != nil { @@ -361,7 +353,7 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { defer conn.Close() // Execute once immediately before starting the loop - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID); err != nil { + if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { logger.Error("Failed to send UDP hole punch: %v", err) } @@ -374,7 +366,7 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { logger.Info("Stopping UDP holepunch") return case <-ticker.C: - if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID); err != nil { + if err := sendUDPHolePunchWithConn(conn, remoteAddr, olmID, serverPubKey); err != nil { logger.Error("Failed to send UDP hole punch: %v", err) } } diff --git a/main.go b/main.go index 3abfb1a..867d39a 100644 --- a/main.go +++ b/main.go @@ -420,6 +420,44 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { } olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { + // THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY + logger.Debug("Received message: %v", msg.Data) + + type LegacyHolePunchData struct { + ServerPubKey string `json:"serverPubKey"` + Endpoint string `json:"endpoint"` + } + + var legacyHolePunchData LegacyHolePunchData + + jsonData, err := json.Marshal(msg.Data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return + } + + if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return + } + + // Stop any existing hole punch goroutines by closing the current channel + select { + case <-stopHolepunch: + // Channel already closed + default: + close(stopHolepunch) + } + + // Create a new stopHolepunch channel for the new set of goroutines + stopHolepunch = make(chan struct{}) + + // Start hole punching for each exit node + logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) + go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) + }) + + olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) jsonData, err := json.Marshal(msg.Data) @@ -433,9 +471,22 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { return } - gerbilServerPubKey = holePunchData.ServerPubKey + // Stop any existing hole punch goroutines by closing the current channel + select { + case <-stopHolepunch: + // Channel already closed + default: + close(stopHolepunch) + } - go keepSendingUDPHolePunch(holePunchData.Endpoint, id, sourcePort) + // Create a new stopHolepunch channel for the new set of goroutines + stopHolepunch = make(chan struct{}) + + // Start hole punching for each exit node + for _, exitNode := range holePunchData.ExitNodes { + logger.Info("Starting hole punch for exit node: %s with public key: %s", exitNode.Endpoint, exitNode.PublicKey) + go keepSendingUDPHolePunch(exitNode.Endpoint, id, sourcePort, exitNode.PublicKey) + } }) // Register handlers for different message types