From bb1318278a9c85b631f857778415be4a327ae343 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 19 Jun 2025 15:59:21 -0400 Subject: [PATCH] Reorg and add timeout --- main.go | 72 +++++++++++---------------------------------------------- util.go | 68 +++++++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 77 insertions(+), 63 deletions(-) diff --git a/main.go b/main.go index a646672..b4025f8 100644 --- a/main.go +++ b/main.go @@ -86,6 +86,7 @@ var ( tlsPrivateKey string dockerSocket string pingInterval = 1 * time.Second + pingTimeout = 2 * time.Second publicKey wgtypes.Key pingStopChan chan struct{} stopFunc func() @@ -107,6 +108,7 @@ func main() { tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") pingIntervalStr := os.Getenv("PING_INTERVAL") + pingTimeoutStr := os.Getenv("PING_TIMEOUT") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -146,6 +148,9 @@ func main() { if pingIntervalStr == "" { flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)") } + if pingTimeoutStr == "" { + flag.StringVar(&pingTimeoutStr, "ping-timeout", "2s", " Timeout for each ping (default 2s)") + } if pingIntervalStr != "" { pingInterval, err = time.ParseDuration(pingIntervalStr) @@ -155,6 +160,14 @@ func main() { } } + if pingTimeoutStr != "" { + pingTimeout, err = time.ParseDuration(pingTimeoutStr) + if err != nil { + fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 2 seconds\n", pingTimeoutStr) + pingTimeout = 2 * time.Second + } + } + // do a --version check version := flag.Bool("version", false, "Print the version") @@ -336,7 +349,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Info("WireGuard device created. Lets ping the server now...") // Even if pingWithRetry returns an error, it will continue trying in the background - _ = pingWithRetry(tnet, wgData.ServerIP) + _ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout) // Always mark as connected and start the proxy manager regardless of initial ping result // as the pings will continue in the background @@ -695,60 +708,3 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Info("Exiting...") os.Exit(0) } - -func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} { - initialInterval := pingInterval - maxInterval := 3 * time.Second - currentInterval := initialInterval - consecutiveFailures := 0 - connectionLost := false - - pingStopChan := make(chan struct{}) - - go func() { - ticker := time.NewTicker(currentInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - _, err := ping(tnet, serverIP) - if err != nil { - consecutiveFailures++ - logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err) - if consecutiveFailures >= 3 && currentInterval < maxInterval { - if !connectionLost { - connectionLost = true - logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") - stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) - } - currentInterval = time.Duration(float64(currentInterval) * 1.5) - if currentInterval > maxInterval { - currentInterval = maxInterval - } - ticker.Reset(currentInterval) - logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval) - } - } else { - if connectionLost { - connectionLost = false - logger.Info("Connection to server restored!") - } - if currentInterval > initialInterval { - currentInterval = time.Duration(float64(currentInterval) * 0.8) - if currentInterval < initialInterval { - currentInterval = initialInterval - } - ticker.Reset(currentInterval) - logger.Info("Decreased ping check interval to %v after successful ping", currentInterval) - } - consecutiveFailures = 0 - } - case <-pingStopChan: - logger.Info("Stopping ping check") - return - } - } - }() - - return pingStopChan -} diff --git a/util.go b/util.go index 43a48bf..a10c94f 100644 --- a/util.go +++ b/util.go @@ -13,6 +13,7 @@ import ( "github.com/fosrl/newt/logger" "github.com/fosrl/newt/proxy" + "github.com/fosrl/newt/websocket" "golang.org/x/exp/rand" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" @@ -34,7 +35,7 @@ func fixKey(key string) string { return hex.EncodeToString(decoded) } -func ping(tnet *netstack.Net, dst string) (time.Duration, error) { +func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) { logger.Debug("Pinging %s", dst) socket, err := tnet.Dial("ping4", dst) if err != nil { @@ -52,7 +53,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) { return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) } - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil { + if err := socket.SetReadDeadline(time.Now().Add(timeout)); err != nil { return 0, fmt.Errorf("failed to set read deadline: %w", err) } @@ -89,7 +90,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) { return latency, nil } -func pingWithRetry(tnet *netstack.Net, dst string) error { +func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) error { const ( initialMaxAttempts = 5 initialRetryDelay = 2 * time.Second @@ -101,7 +102,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error { // First try with the initial parameters logger.Info("Ping attempt %d", attempt) - if latency, err := ping(tnet, dst); err == nil { + if latency, err := ping(tnet, dst, timeout); err == nil { // Successful ping logger.Info("Ping latency: %v", latency) @@ -118,7 +119,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error { for { logger.Info("Ping attempt %d", attempt) - if latency, err := ping(tnet, dst); err != nil { + if latency, err := ping(tnet, dst, timeout); err != nil { logger.Warn("Ping attempt %d failed: %v", attempt, err) // Increase delay after certain thresholds but cap it @@ -146,6 +147,63 @@ func pingWithRetry(tnet *netstack.Net, dst string) error { return fmt.Errorf("initial ping attempts failed, continuing in background") } +func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} { + initialInterval := pingInterval + maxInterval := 3 * time.Second + currentInterval := initialInterval + consecutiveFailures := 0 + connectionLost := false + + pingStopChan := make(chan struct{}) + + go func() { + ticker := time.NewTicker(currentInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + _, err := ping(tnet, serverIP, pingTimeout) + if err != nil { + consecutiveFailures++ + logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err) + if consecutiveFailures >= 3 && currentInterval < maxInterval { + if !connectionLost { + connectionLost = true + logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + } + currentInterval = time.Duration(float64(currentInterval) * 1.5) + if currentInterval > maxInterval { + currentInterval = maxInterval + } + ticker.Reset(currentInterval) + logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval) + } + } else { + if connectionLost { + connectionLost = false + logger.Info("Connection to server restored!") + } + if currentInterval > initialInterval { + currentInterval = time.Duration(float64(currentInterval) * 0.8) + if currentInterval < initialInterval { + currentInterval = initialInterval + } + ticker.Reset(currentInterval) + logger.Info("Decreased ping check interval to %v after successful ping", currentInterval) + } + consecutiveFailures = 0 + } + case <-pingStopChan: + logger.Info("Stopping ping check") + return + } + } + }() + + return pingStopChan +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG":