New tunnel reconnect works

This commit is contained in:
Owen
2025-06-19 15:55:47 -04:00
parent 4b64b04603
commit 1c75eb3bee
3 changed files with 151 additions and 141 deletions

255
main.go
View File

@@ -4,7 +4,6 @@ import (
"encoding/json"
"flag"
"fmt"
"math"
"net/http"
"net/netip"
"os"
@@ -59,6 +58,16 @@ type ExitNode struct {
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
type ExitNodePingResult struct {
ExitNodeID int `json:"exitNodeId"`
LatencyMs int64 `json:"latencyMs"`
Weight float64 `json:"weight"`
Error string `json:"error,omitempty"`
Name string `json:"exitNodeName"`
Endpoint string `json:"endpoint"`
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
var (
endpoint string
id string
@@ -76,7 +85,10 @@ var (
updownScript string
tlsPrivateKey string
dockerSocket string
pingInterval = 1 * time.Second
publicKey wgtypes.Key
pingStopChan chan struct{}
stopFunc func()
)
func main() {
@@ -94,6 +106,7 @@ func main() {
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET")
pingIntervalStr := os.Getenv("PING_INTERVAL")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -130,6 +143,17 @@ func main() {
if dockerSocket == "" {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
}
if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
}
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 1 second\n", pingIntervalStr)
pingInterval = 1 * time.Second
}
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -216,38 +240,41 @@ func main() {
}
}
pingStopChan := make(chan struct{})
defer close(pingStopChan)
// Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message")
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
}
if connected {
if pingStopChan != nil {
// Stop the ping check
close(pingStopChan)
pingStopChan = nil
}
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
// Close WireGuard device first - this will automatically close the TUN device
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
// Clear references but don't manually close since dev.Close() already did it
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
tun = nil // Don't call tun.Close() here since dev.Close() already closed it
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
}
@@ -315,7 +342,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// as the pings will continue in the background
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, client, pingStopChan)
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client)
}
// Create proxy manager
@@ -348,30 +375,32 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received disconnect message")
if pingStopChan != nil {
// Stop the ping check
close(pingStopChan)
pingStopChan = nil
}
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
// Close WireGuard device first - this will automatically close the TUN device
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
// Clear references but don't manually close since dev.Close() already did it
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
tun = nil // Don't call tun.Close() here since dev.Close() already closed it
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
@@ -380,9 +409,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Info("Received ping message")
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
}
// Parse the incoming list of exit nodes
// Exit nodes is a json
var exitNodeData ExitNodeData
jsonData, err := json.Marshal(msg.Data)
@@ -408,8 +440,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
results := make([]nodeResult, len(exitNodes))
const pingAttempts = 3
for i, node := range exitNodes {
start := time.Now()
var totalLatency time.Duration
var lastErr error
successes := 0
client := &http.Client{
Timeout: 5 * time.Second,
}
@@ -420,89 +455,54 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
if !strings.HasSuffix(url, "/ping") {
url = strings.TrimRight(url, "/") + "/ping"
}
resp, err := client.Get(url)
latency := time.Since(start)
if err != nil {
logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err)
results[i] = nodeResult{Node: node, Latency: latency, Err: err}
continue
}
resp.Body.Close()
results[i] = nodeResult{Node: node, Latency: latency, Err: nil}
// logger.Info("Exit node %s latency: %v", node.Name, latency)
}
// we will need to tweak these
const (
latencyPenaltyExponent = 1.5 // make latency matter more
lastNodeScoreBoost = 1.10 // 10% preference for the last used node
scoreTolerancePercent = 5.0 // allow last node if within 5% of best score
)
var bestNode *ExitNode
var bestScore float64 = -1e12
var bestLatency time.Duration = 1e12
type ExitNodeScore struct {
Node ExitNode
Score float64
Latency time.Duration
}
var candidateNodes []ExitNodeScore
for _, res := range results {
if res.Err != nil || res.Node.Weight <= 0 {
continue
}
latencyMs := float64(res.Latency.Milliseconds())
score := res.Node.Weight / math.Pow(latencyMs, latencyPenaltyExponent)
// slight boost if this is the last used node
if res.Node.WasPreviouslyConnected == true {
score *= lastNodeScoreBoost
}
logger.Info("Exit node %s with score: %.2f (latency: %dms, weight: %.2f)", res.Node.Name, score, res.Latency.Milliseconds(), res.Node.Weight)
candidateNodes = append(candidateNodes, ExitNodeScore{Node: res.Node, Score: score, Latency: res.Latency})
if score > bestScore {
bestScore = score
bestLatency = res.Latency
bestNode = &res.Node
} else if score == bestScore && res.Latency < bestLatency {
bestLatency = res.Latency
bestNode = &res.Node
}
}
// check if last used node is close enough in score
for _, cand := range candidateNodes {
if cand.Node.WasPreviouslyConnected {
if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint)
bestNode = &cand.Node
for j := 0; j < pingAttempts; j++ {
start := time.Now()
resp, err := client.Get(url)
latency := time.Since(start)
if err != nil {
lastErr = err
logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err)
continue
}
break
resp.Body.Close()
totalLatency += latency
successes++
}
var avgLatency time.Duration
if successes > 0 {
avgLatency = totalLatency / time.Duration(successes)
}
if successes == 0 {
results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr}
} else {
results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil}
}
}
if bestNode == nil {
logger.Warn("No suitable exit node found")
return
// Prepare data to send to the cloud for selection
var pingResults []ExitNodePingResult
for _, res := range results {
errMsg := ""
if res.Err != nil {
errMsg = res.Err.Error()
}
pingResults = append(pingResults, ExitNodePingResult{
ExitNodeID: res.Node.ID,
LatencyMs: res.Latency.Milliseconds(),
Weight: res.Node.Weight,
Error: errMsg,
Name: res.Node.Name,
Endpoint: res.Node.Endpoint,
WasPreviouslyConnected: res.Node.WasPreviouslyConnected,
})
}
logger.Info("Selected exit node: %s (%s)", bestNode.Name, bestNode.Endpoint)
err = client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"exitNodeId": bestNode.ID,
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return
}
// Send the ping results to the cloud for selection
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
"pingResults": pingResults,
}, 1*time.Second)
logger.Info("Sent exit node ping results to cloud for selection")
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
@@ -648,10 +648,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Debug("Public key: %s", publicKey)
// request from the server the list of nodes to ping at newt/ping/request
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
if wgService != nil {
wgService.LoadRemoteConfig()
@@ -699,75 +696,59 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
os.Exit(0)
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) {
initialInterval := 10 * time.Second
maxInterval := 60 * time.Second
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
initialInterval := pingInterval
maxInterval := 3 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
connectionLost := false
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
pingStopChan := make(chan struct{})
go func() {
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP)
if err != nil {
consecutiveFailures++
// Check if this is the first failure (connection just lost)
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
}
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
consecutiveFailures, err)
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
// Increase interval if we have consistent failures, with a maximum cap
if consecutiveFailures >= 5 && currentInterval < maxInterval {
// Increase by 50% each time, up to the maximum
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
if consecutiveFailures >= 3 && currentInterval < maxInterval {
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
}
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures",
currentInterval)
// Restart the connection flow
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
}
} else {
// Check if connection was previously lost and is now restored
if connectionLost {
connectionLost = false
logger.Info("Connection to server restored!")
}
// On success, if we've backed off, gradually return to normal interval
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping",
currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
}
consecutiveFailures = 0
}
case <-stopChan:
case <-pingStopChan:
logger.Info("Stopping ping check")
return
}
}
}()
return pingStopChan
}

View File

@@ -44,7 +44,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
Data: []byte("f"),
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
@@ -52,7 +52,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
return 0, fmt.Errorf("failed to set read deadline: %w", err)
}
@@ -84,12 +84,14 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
latency := time.Since(start)
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
return latency, nil
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
const (
initialMaxAttempts = 15
initialMaxAttempts = 5
initialRetryDelay = 2 * time.Second
maxRetryDelay = 60 * time.Second // Cap the maximum delay
)

View File

@@ -9,11 +9,12 @@ import (
"net/http"
"net/url"
"os"
"software.sslmate.com/src/go-pkcs12"
"strings"
"sync"
"time"
"software.sslmate.com/src/go-pkcs12"
"github.com/fosrl/newt/logger"
"github.com/gorilla/websocket"
)
@@ -126,6 +127,32 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg)
}
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
stopChan := make(chan struct{})
go func() {
err := c.SendMessage(messageType, data) // Send immediately
if err != nil {
logger.Error("Failed to send initial message: %v", err)
}
ticker := time.NewTicker(interval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
err = c.SendMessage(messageType, data)
if err != nil {
logger.Error("Failed to send message: %v", err)
}
case <-stopChan:
return
}
}
}()
return func() {
close(stopChan)
}
}
// RegisterHandler registers a handler for a specific message type
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlersMux.Lock()