From 848ac6b0c4706d8b95fda684bf5b81b6755ef2d5 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 24 Jul 2025 14:44:12 -0700 Subject: [PATCH] Holepunch but relay by default Former-commit-id: 5302f9da34ff58e87422efddc1d330dd9e6f1e6d --- common.go | 31 +----------------------- main.go | 30 ++++++++++++++++-------- peermonitor/peermonitor.go | 48 +++++++++++++++++++++----------------- service_windows.go | 6 ++--- 4 files changed, 51 insertions(+), 64 deletions(-) diff --git a/common.go b/common.go index 07f8fb8..0274395 100644 --- a/common.go +++ b/common.go @@ -65,7 +65,7 @@ type EncryptedHolePunchMessage struct { var ( peerMonitor *peermonitor.PeerMonitor stopHolepunch chan struct{} - stopRegister chan struct{} + stopRegister func() stopPing chan struct{} olmToken string gerbilServerPubKey string @@ -378,35 +378,6 @@ func keepSendingUDPHolePunch(endpoint string, olmID string, sourcePort uint16) { } } -func sendRegistration(olm *websocket.Client, publicKey string) error { - err := olm.SendMessage("olm/wg/register", map[string]interface{}{ - "publicKey": publicKey, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return err - } - logger.Info("Sent registration message") - return nil -} - -func keepSendingRegistration(olm *websocket.Client, publicKey string) { - ticker := time.NewTicker(1 * time.Second) - defer ticker.Stop() - - for { - select { - case <-stopRegister: - logger.Info("Stopping registration messages") - return - case <-ticker.C: - if err := sendRegistration(olm, publicKey); err != nil { - logger.Error("Failed to send periodic registration: %v", err) - } - } - } -} - func FindAvailableUDPPort(minPort, maxPort uint16) (uint16, error) { if maxPort < minPort { return 0, fmt.Errorf("invalid port range: min=%d, max=%d", minPort, maxPort) diff --git a/main.go b/main.go index 90a67a7..265433e 100644 --- a/main.go +++ b/main.go @@ -157,7 +157,7 @@ func runOlmMain(ctx context.Context) { func runOlmMainWithArgs(ctx context.Context, args []string) { // Log that we've entered the main function - fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) + // fmt.Printf("runOlmMainWithArgs() called with args: %v\n", args) // Create a new FlagSet for parsing service arguments serviceFlags := flag.NewFlagSet("service", flag.ContinueOnError) @@ -179,10 +179,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { testTarget string // Add this var for test target pingInterval time.Duration pingTimeout time.Duration + doHolepunch bool ) stopHolepunch = make(chan struct{}) - stopRegister = make(chan struct{}) stopPing = make(chan struct{}) // if PANGOLIN_ENDPOINT, OLM_ID, and OLM_SECRET are set as environment variables, they will be used as default values @@ -196,6 +196,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { httpAddr = os.Getenv("HTTP_ADDR") pingIntervalStr := os.Getenv("PING_INTERVAL") pingTimeoutStr := os.Getenv("PING_TIMEOUT") + doHolepunch = os.Getenv("HOLEPUNCH") == "true" // Default to true, can be overridden by flag if endpoint == "" { serviceFlags.StringVar(&endpoint, "endpoint", "", "Endpoint of your Pangolin server") @@ -227,6 +228,8 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { if pingTimeoutStr == "" { serviceFlags.StringVar(&pingTimeoutStr, "ping-timeout", "5s", " Timeout for each ping (default 3s)") } + serviceFlags.BoolVar(&enableHTTP, "enable-http", false, "Enable HTT server for receiving connection requests") + serviceFlags.BoolVar(&doHolepunch, "holepunch", false, "Enable hole punching (default false)") // Parse the service arguments if err := serviceFlags.Parse(args); err != nil { @@ -442,7 +445,10 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { connectTimes++ - close(stopRegister) + if stopRegister != nil { + stopRegister() + stopRegister = nil + } // if there is an existing tunnel then close it if dev != nil { @@ -566,6 +572,7 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { fixKey(privateKey.String()), olm, dev, + doHolepunch, ) // loop over the sites and call ConfigurePeer for each one @@ -791,9 +798,14 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { olm.OnConnect(func() error { publicKey := privateKey.PublicKey() - logger.Debug("Public key: %s", publicKey) - go keepSendingRegistration(olm, publicKey.String()) + logger.Debug("Sending registration message to server with public key: %s and relay: %v", publicKey, !doHolepunch) + + stopRegister = olm.SendMessageInterval("olm/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "relay": !doHolepunch, + }, 1*time.Second) + go keepSendingPing(olm) if httpServer != nil { @@ -832,11 +844,9 @@ func runOlmMainWithArgs(ctx context.Context, args []string) { close(stopHolepunch) } - select { - case <-stopRegister: - // Channel already closed - default: - close(stopRegister) + if stopRegister != nil { + stopRegister() + stopRegister = nil } select { diff --git a/peermonitor/peermonitor.go b/peermonitor/peermonitor.go index 9570aec..684d767 100644 --- a/peermonitor/peermonitor.go +++ b/peermonitor/peermonitor.go @@ -26,31 +26,33 @@ type WireGuardConfig struct { // PeerMonitor handles monitoring the connection status to multiple WireGuard peers type PeerMonitor struct { - monitors map[int]*wgtester.Client - configs map[int]*WireGuardConfig - callback PeerMonitorCallback - mutex sync.Mutex - running bool - interval time.Duration - timeout time.Duration - maxAttempts int - privateKey string - wsClient *websocket.Client - device *device.Device + monitors map[int]*wgtester.Client + configs map[int]*WireGuardConfig + callback PeerMonitorCallback + mutex sync.Mutex + running bool + interval time.Duration + timeout time.Duration + maxAttempts int + privateKey string + wsClient *websocket.Client + device *device.Device + handleRelaySwitch bool // Whether to handle relay switching } // NewPeerMonitor creates a new peer monitor with the given callback -func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device) *PeerMonitor { +func NewPeerMonitor(callback PeerMonitorCallback, privateKey string, wsClient *websocket.Client, device *device.Device, handleRelaySwitch bool) *PeerMonitor { return &PeerMonitor{ - monitors: make(map[int]*wgtester.Client), - configs: make(map[int]*WireGuardConfig), - callback: callback, - interval: 1 * time.Second, // Default check interval - timeout: 2500 * time.Millisecond, - maxAttempts: 8, - privateKey: privateKey, - wsClient: wsClient, - device: device, + monitors: make(map[int]*wgtester.Client), + configs: make(map[int]*WireGuardConfig), + callback: callback, + interval: 1 * time.Second, // Default check interval + timeout: 2500 * time.Millisecond, + maxAttempts: 8, + privateKey: privateKey, + wsClient: wsClient, + device: device, + handleRelaySwitch: handleRelaySwitch, } } @@ -214,6 +216,10 @@ persistent_keepalive_interval=1`, pm.privateKey, config.PublicKey, config.Server // sendRelay sends a relay message to the server func (pm *PeerMonitor) sendRelay(siteID int) error { + if !pm.handleRelaySwitch { + return nil + } + if pm.wsClient == nil { return fmt.Errorf("websocket client is nil") } diff --git a/service_windows.go b/service_windows.go index a12cde0..f4dd7ff 100644 --- a/service_windows.go +++ b/service_windows.go @@ -379,7 +379,7 @@ func debugService(args []string) error { } } - fmt.Printf("Starting service in debug mode...\n") + // fmt.Printf("Starting service in debug mode...\n") // Start the service err := startService([]string{}) // Pass empty args since we already saved them @@ -387,8 +387,8 @@ func debugService(args []string) error { return fmt.Errorf("failed to start service: %v", err) } - fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n") - fmt.Printf("================================================================================\n") + // fmt.Printf("Service started. Watching logs (Press Ctrl+C to stop watching)...\n") + // fmt.Printf("================================================================================\n") // Watch the log file return watchLogFile(true)