From 1c75eb3bee208a658b3cd03caa8b8956284e6972 Mon Sep 17 00:00:00 2001 From: Owen Date: Thu, 19 Jun 2025 15:55:47 -0400 Subject: [PATCH] New tunnel reconnect works --- main.go | 255 ++++++++++++++++++++------------------------ util.go | 8 +- websocket/client.go | 29 ++++- 3 files changed, 151 insertions(+), 141 deletions(-) diff --git a/main.go b/main.go index 59ad194..a646672 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,6 @@ import ( "encoding/json" "flag" "fmt" - "math" "net/http" "net/netip" "os" @@ -59,6 +58,16 @@ type ExitNode struct { WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` } +type ExitNodePingResult struct { + ExitNodeID int `json:"exitNodeId"` + LatencyMs int64 `json:"latencyMs"` + Weight float64 `json:"weight"` + Error string `json:"error,omitempty"` + Name string `json:"exitNodeName"` + Endpoint string `json:"endpoint"` + WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` +} + var ( endpoint string id string @@ -76,7 +85,10 @@ var ( updownScript string tlsPrivateKey string dockerSocket string + pingInterval = 1 * time.Second publicKey wgtypes.Key + pingStopChan chan struct{} + stopFunc func() ) func main() { @@ -94,6 +106,7 @@ func main() { acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") dockerSocket = os.Getenv("DOCKER_SOCKET") + pingIntervalStr := os.Getenv("PING_INTERVAL") if endpoint == "" { flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") @@ -130,6 +143,17 @@ func main() { if dockerSocket == "" { flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") } + if pingIntervalStr == "" { + flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)") + } + + if pingIntervalStr != "" { + pingInterval, err = time.ParseDuration(pingIntervalStr) + if err != nil { + fmt.Printf("Invalid PING_INTERVAL value: %s, using default 1 second\n", pingIntervalStr) + pingInterval = 1 * time.Second + } + } // do a --version check version := flag.Bool("version", false, "Print the version") @@ -216,38 +240,41 @@ func main() { } } - pingStopChan := make(chan struct{}) - defer close(pingStopChan) - // Register handlers for different message types client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { logger.Info("Received registration message") + if stopFunc != nil { + stopFunc() // stop the ws from sending more requests + stopFunc = nil // reset stopFunc to nil to avoid double stopping + } if connected { + if pingStopChan != nil { + // Stop the ping check + close(pingStopChan) + pingStopChan = nil + } + // Stop proxy manager if running if pm != nil { pm.Stop() pm = nil } - // Close WireGuard device if running + // Close WireGuard device first - this will automatically close the TUN device if dev != nil { dev.Close() dev = nil } - // Close TUN/netstack if running + // Clear references but don't manually close since dev.Close() already did it if tnet != nil { tnet = nil } if tun != nil { - tun.Close() - tun = nil + tun = nil // Don't call tun.Close() here since dev.Close() already closed it } - // Stop the ping check - close(pingStopChan) - // Mark as disconnected connected = false } @@ -315,7 +342,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // as the pings will continue in the background if !connected { logger.Info("Starting ping check") - startPingCheck(tnet, wgData.ServerIP, client, pingStopChan) + pingStopChan = startPingCheck(tnet, wgData.ServerIP, client) } // Create proxy manager @@ -348,30 +375,32 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) { logger.Info("Received disconnect message") + if pingStopChan != nil { + // Stop the ping check + close(pingStopChan) + pingStopChan = nil + } + // Stop proxy manager if running if pm != nil { pm.Stop() pm = nil } - // Close WireGuard device if running + // Close WireGuard device first - this will automatically close the TUN device if dev != nil { dev.Close() dev = nil } - // Close TUN/netstack if running + // Clear references but don't manually close since dev.Close() already did it if tnet != nil { tnet = nil } if tun != nil { - tun.Close() - tun = nil + tun = nil // Don't call tun.Close() here since dev.Close() already closed it } - // Stop the ping check - close(pingStopChan) - // Mark as disconnected connected = false @@ -380,9 +409,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { logger.Info("Received ping message") + if stopFunc != nil { + stopFunc() // stop the ws from sending more requests + stopFunc = nil // reset stopFunc to nil to avoid double stopping + } // Parse the incoming list of exit nodes - // Exit nodes is a json var exitNodeData ExitNodeData jsonData, err := json.Marshal(msg.Data) @@ -408,8 +440,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } results := make([]nodeResult, len(exitNodes)) + const pingAttempts = 3 for i, node := range exitNodes { - start := time.Now() + var totalLatency time.Duration + var lastErr error + successes := 0 client := &http.Client{ Timeout: 5 * time.Second, } @@ -420,89 +455,54 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if !strings.HasSuffix(url, "/ping") { url = strings.TrimRight(url, "/") + "/ping" } - resp, err := client.Get(url) - latency := time.Since(start) - if err != nil { - logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err) - results[i] = nodeResult{Node: node, Latency: latency, Err: err} - continue - } - resp.Body.Close() - results[i] = nodeResult{Node: node, Latency: latency, Err: nil} - // logger.Info("Exit node %s latency: %v", node.Name, latency) - } - - // we will need to tweak these - const ( - latencyPenaltyExponent = 1.5 // make latency matter more - lastNodeScoreBoost = 1.10 // 10% preference for the last used node - scoreTolerancePercent = 5.0 // allow last node if within 5% of best score - ) - - var bestNode *ExitNode - var bestScore float64 = -1e12 - var bestLatency time.Duration = 1e12 - - type ExitNodeScore struct { - Node ExitNode - Score float64 - Latency time.Duration - } - var candidateNodes []ExitNodeScore - - for _, res := range results { - if res.Err != nil || res.Node.Weight <= 0 { - continue - } - - latencyMs := float64(res.Latency.Milliseconds()) - score := res.Node.Weight / math.Pow(latencyMs, latencyPenaltyExponent) - - // slight boost if this is the last used node - if res.Node.WasPreviouslyConnected == true { - score *= lastNodeScoreBoost - } - - logger.Info("Exit node %s with score: %.2f (latency: %dms, weight: %.2f)", res.Node.Name, score, res.Latency.Milliseconds(), res.Node.Weight) - - candidateNodes = append(candidateNodes, ExitNodeScore{Node: res.Node, Score: score, Latency: res.Latency}) - - if score > bestScore { - bestScore = score - bestLatency = res.Latency - bestNode = &res.Node - } else if score == bestScore && res.Latency < bestLatency { - bestLatency = res.Latency - bestNode = &res.Node - } - } - - // check if last used node is close enough in score - for _, cand := range candidateNodes { - if cand.Node.WasPreviouslyConnected { - if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) { - logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint) - bestNode = &cand.Node + for j := 0; j < pingAttempts; j++ { + start := time.Now() + resp, err := client.Get(url) + latency := time.Since(start) + if err != nil { + lastErr = err + logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err) + continue } - break + resp.Body.Close() + totalLatency += latency + successes++ + } + var avgLatency time.Duration + if successes > 0 { + avgLatency = totalLatency / time.Duration(successes) + } + if successes == 0 { + results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr} + } else { + results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil} } } - if bestNode == nil { - logger.Warn("No suitable exit node found") - return + // Prepare data to send to the cloud for selection + var pingResults []ExitNodePingResult + for _, res := range results { + errMsg := "" + if res.Err != nil { + errMsg = res.Err.Error() + } + pingResults = append(pingResults, ExitNodePingResult{ + ExitNodeID: res.Node.ID, + LatencyMs: res.Latency.Milliseconds(), + Weight: res.Node.Weight, + Error: errMsg, + Name: res.Node.Name, + Endpoint: res.Node.Endpoint, + WasPreviouslyConnected: res.Node.WasPreviouslyConnected, + }) } - logger.Info("Selected exit node: %s (%s)", bestNode.Name, bestNode.Endpoint) - - err = client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": publicKey.String(), - "exitNodeId": bestNode.ID, - }) - if err != nil { - logger.Error("Failed to send registration message: %v", err) - return - } + // Send the ping results to the cloud for selection + stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + "pingResults": pingResults, + }, 1*time.Second) + logger.Info("Sent exit node ping results to cloud for selection") }) client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { @@ -648,10 +648,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub logger.Debug("Public key: %s", publicKey) // request from the server the list of nodes to ping at newt/ping/request - err := client.SendMessage("newt/ping/request", map[string]interface{}{}) - if err != nil { - logger.Error("Failed to send ping request: %v", err) - } + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) if wgService != nil { wgService.LoadRemoteConfig() @@ -699,75 +696,59 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub os.Exit(0) } -func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) { - initialInterval := 10 * time.Second - maxInterval := 60 * time.Second +func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} { + initialInterval := pingInterval + maxInterval := 3 * time.Second currentInterval := initialInterval consecutiveFailures := 0 connectionLost := false - ticker := time.NewTicker(currentInterval) - defer ticker.Stop() + + pingStopChan := make(chan struct{}) go func() { + ticker := time.NewTicker(currentInterval) + defer ticker.Stop() for { select { case <-ticker.C: _, err := ping(tnet, serverIP) if err != nil { consecutiveFailures++ - - // Check if this is the first failure (connection just lost) - if !connectionLost { - connectionLost = true - logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") - logger.Warn("Please check your internet connection and ensure the Pangolin server is online.") - logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.") - } - - logger.Warn("Periodic ping failed (%d consecutive failures): %v", - consecutiveFailures, err) - logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?") - - // Increase interval if we have consistent failures, with a maximum cap - if consecutiveFailures >= 5 && currentInterval < maxInterval { - // Increase by 50% each time, up to the maximum + logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err) + if consecutiveFailures >= 3 && currentInterval < maxInterval { + if !connectionLost { + connectionLost = true + logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") + stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second) + } currentInterval = time.Duration(float64(currentInterval) * 1.5) if currentInterval > maxInterval { currentInterval = maxInterval } ticker.Reset(currentInterval) - logger.Debug("Increased ping check interval to %v due to consecutive failures", - currentInterval) - - // Restart the connection flow - err := client.SendMessage("newt/ping/request", map[string]interface{}{}) - if err != nil { - logger.Error("Failed to send ping request: %v", err) - } + logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval) } } else { - // Check if connection was previously lost and is now restored if connectionLost { connectionLost = false logger.Info("Connection to server restored!") } - - // On success, if we've backed off, gradually return to normal interval if currentInterval > initialInterval { currentInterval = time.Duration(float64(currentInterval) * 0.8) if currentInterval < initialInterval { currentInterval = initialInterval } ticker.Reset(currentInterval) - logger.Info("Decreased ping check interval to %v after successful ping", - currentInterval) + logger.Info("Decreased ping check interval to %v after successful ping", currentInterval) } consecutiveFailures = 0 } - case <-stopChan: + case <-pingStopChan: logger.Info("Stopping ping check") return } } }() + + return pingStopChan } diff --git a/util.go b/util.go index 808ae2b..43a48bf 100644 --- a/util.go +++ b/util.go @@ -44,7 +44,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) { requestPing := icmp.Echo{ Seq: rand.Intn(1 << 16), - Data: []byte("gopher burrow"), + Data: []byte("f"), } icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) @@ -52,7 +52,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) { return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) } - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { + if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil { return 0, fmt.Errorf("failed to set read deadline: %w", err) } @@ -84,12 +84,14 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) { latency := time.Since(start) + logger.Debug("Ping to %s successful, latency: %v", dst, latency) + return latency, nil } func pingWithRetry(tnet *netstack.Net, dst string) error { const ( - initialMaxAttempts = 15 + initialMaxAttempts = 5 initialRetryDelay = 2 * time.Second maxRetryDelay = 60 * time.Second // Cap the maximum delay ) diff --git a/websocket/client.go b/websocket/client.go index 1d75ea8..6b34627 100644 --- a/websocket/client.go +++ b/websocket/client.go @@ -9,11 +9,12 @@ import ( "net/http" "net/url" "os" - "software.sslmate.com/src/go-pkcs12" "strings" "sync" "time" + "software.sslmate.com/src/go-pkcs12" + "github.com/fosrl/newt/logger" "github.com/gorilla/websocket" ) @@ -126,6 +127,32 @@ func (c *Client) SendMessage(messageType string, data interface{}) error { return c.conn.WriteJSON(msg) } +func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) { + stopChan := make(chan struct{}) + go func() { + err := c.SendMessage(messageType, data) // Send immediately + if err != nil { + logger.Error("Failed to send initial message: %v", err) + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + err = c.SendMessage(messageType, data) + if err != nil { + logger.Error("Failed to send message: %v", err) + } + case <-stopChan: + return + } + } + }() + return func() { + close(stopChan) + } +} + // RegisterHandler registers a handler for a specific message type func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { c.handlersMux.Lock()