diff --git a/.github/workflows/cicd.yml b/.github/workflows/cicd.yml index 37063b5..5dee76a 100644 --- a/.github/workflows/cicd.yml +++ b/.github/workflows/cicd.yml @@ -23,15 +23,15 @@ jobs: with: go-version: 1.24 - # - name: Update version in main.go - # run: | - # TAG=${{ env.TAG }} - # if [ -f main.go ]; then - # sed -i 's/Olm version replaceme/Olm version '"$TAG"'/' main.go - # echo "Updated main.go with version $TAG" - # else - # echo "main.go not found" - # fi + - name: Update version in main.go + run: | + TAG=${{ env.TAG }} + if [ -f main.go ]; then + sed -i 's/version_replaceme/'"$TAG"'/' main.go + echo "Updated main.go with version $TAG" + else + echo "main.go not found" + fi - name: Build binaries run: | diff --git a/common.go b/common.go index 6bf2bc4..df01b33 100644 --- a/common.go +++ b/common.go @@ -68,11 +68,12 @@ type EncryptedHolePunchMessage struct { } var ( - peerMonitor *peermonitor.PeerMonitor - stopHolepunch chan struct{} - stopRegister func() - stopPing chan struct{} - olmToken string + peerMonitor *peermonitor.PeerMonitor + stopHolepunch chan struct{} + stopRegister func() + stopPing chan struct{} + olmToken string + holePunchRunning bool ) const ( @@ -321,7 +322,117 @@ func encryptPayload(payload []byte, serverPublicKey string) (interface{}, error) return encryptedMsg, nil } +func keepSendingUDPHolePunchToMultipleExitNodes(exitNodes []ExitNode, olmID string, sourcePort uint16) { + if len(exitNodes) == 0 { + logger.Warn("No exit nodes provided for hole punching") + return + } + + // Check if hole punching is already running + if holePunchRunning { + logger.Debug("UDP hole punch already running, skipping new request") + return + } + + // Set the flag to indicate hole punching is running + holePunchRunning = true + defer func() { + holePunchRunning = false + logger.Info("UDP hole punch goroutine ended") + }() + + logger.Info("Starting UDP hole punch to %d exit nodes", len(exitNodes)) + defer logger.Info("UDP hole punch goroutine ended for all exit nodes") + + // Create the UDP connection once and reuse it for all exit nodes + localAddr := &net.UDPAddr{ + Port: int(sourcePort), + IP: net.IPv4zero, + } + + conn, err := net.ListenUDP("udp", localAddr) + if err != nil { + logger.Error("Failed to bind UDP socket: %v", err) + return + } + defer conn.Close() + + // Resolve all endpoints upfront + type resolvedExitNode struct { + remoteAddr *net.UDPAddr + publicKey string + endpointName string + } + + var resolvedNodes []resolvedExitNode + for _, exitNode := range exitNodes { + host, err := resolveDomain(exitNode.Endpoint) + if err != nil { + logger.Error("Failed to resolve endpoint %s: %v", exitNode.Endpoint, err) + continue + } + + serverAddr := host + ":21820" + remoteAddr, err := net.ResolveUDPAddr("udp", serverAddr) + if err != nil { + logger.Error("Failed to resolve UDP address for %s: %v", exitNode.Endpoint, err) + continue + } + + resolvedNodes = append(resolvedNodes, resolvedExitNode{ + remoteAddr: remoteAddr, + publicKey: exitNode.PublicKey, + endpointName: exitNode.Endpoint, + }) + logger.Info("Resolved exit node: %s -> %s", exitNode.Endpoint, remoteAddr.String()) + } + + if len(resolvedNodes) == 0 { + logger.Error("No exit nodes could be resolved") + return + } + + // Send initial hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { + logger.Error("Failed to send initial UDP hole punch to %s: %v", node.endpointName, err) + } + } + + ticker := time.NewTicker(250 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-stopHolepunch: + logger.Info("Stopping UDP holepunch for all exit nodes") + return + case <-ticker.C: + // Send hole punch to all exit nodes + for _, node := range resolvedNodes { + if err := sendUDPHolePunchWithConn(conn, node.remoteAddr, olmID, node.publicKey); err != nil { + logger.Error("Failed to send UDP hole punch to %s: %v", node.endpointName, err) + } + } + } + } +} + func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16, serverPubKey string) { + + // Check if hole punching is already running + if holePunchRunning { + logger.Debug("UDP hole punch already running, skipping new request") + return + } + + // Set the flag to indicate hole punching is running + holePunchRunning = true + defer func() { + holePunchRunning = false + logger.Info("UDP hole punch goroutine ended") + }() + logger.Info("Starting UDP hole punch to %s", endpoint) defer logger.Info("UDP hole punch goroutine ended for %s", endpoint) diff --git a/main.go b/main.go index 867d39a..b883b69 100644 --- a/main.go +++ b/main.go @@ -232,6 +232,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests") serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)") + version := serviceFlags.Bool("version", false, "Print the version") + // Parse the service arguments if err := serviceFlags.Parse(args); err != nil { fmt.Printf("Error parsing service arguments: %v\n", err) @@ -272,6 +274,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { loggerLevel := parseLogLevel(logLevel) logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + olmVersion := "version_replaceme" + if *version { + fmt.Println("Olm version " + olmVersion) + os.Exit(0) + } else { + logger.Info("Olm version " + olmVersion) + } + // Log startup information logger.Debug("Olm service starting...") logger.Debug("Parameters: endpoint='%s', id='%s', secret='%s'", endpoint, id, secret) @@ -419,44 +429,6 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { os.Exit(1) } - olm.RegisterHandler("olm/wg/holepunch", func(msg websocket.WSMessage) { - // THIS ENDPOINT IS FOR BACKWARD COMPATIBILITY - logger.Debug("Received message: %v", msg.Data) - - type LegacyHolePunchData struct { - ServerPubKey string `json:"serverPubKey"` - Endpoint string `json:"endpoint"` - } - - var legacyHolePunchData LegacyHolePunchData - - jsonData, err := json.Marshal(msg.Data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return - } - - if err := json.Unmarshal(jsonData, &legacyHolePunchData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return - } - - // Stop any existing hole punch goroutines by closing the current channel - select { - case <-stopHolepunch: - // Channel already closed - default: - close(stopHolepunch) - } - - // Create a new stopHolepunch channel for the new set of goroutines - stopHolepunch = make(chan struct{}) - - // Start hole punching for each exit node - logger.Info("Starting hole punch for exit node: %s with public key: %s", legacyHolePunchData.Endpoint, legacyHolePunchData.ServerPubKey) - go keepSendingUDPHolePunch(legacyHolePunchData.Endpoint, id, sourcePort, legacyHolePunchData.ServerPubKey) - }) - olm.RegisterHandler("olm/wg/holepunch/all", func(msg websocket.WSMessage) { logger.Debug("Received message: %v", msg.Data) @@ -471,22 +443,12 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { return } - // Stop any existing hole punch goroutines by closing the current channel - select { - case <-stopHolepunch: - // Channel already closed - default: - close(stopHolepunch) - } - // Create a new stopHolepunch channel for the new set of goroutines stopHolepunch = make(chan struct{}) - // Start hole punching for each exit node - for _, exitNode := range holePunchData.ExitNodes { - logger.Info("Starting hole punch for exit node: %s with public key: %s", exitNode.Endpoint, exitNode.PublicKey) - go keepSendingUDPHolePunch(exitNode.Endpoint, id, sourcePort, exitNode.PublicKey) - } + // Start a single hole punch goroutine for all exit nodes + logger.Info("Starting hole punch for %d exit nodes", len(holePunchData.ExitNodes)) + go keepSendingUDPHolePunchToMultipleExitNodes(holePunchData.ExitNodes, id, sourcePort) }) // Register handlers for different message types