New tunnel reconnect works

This commit is contained in:
Owen
2025-06-19 15:55:47 -04:00
parent 4b64b04603
commit 1c75eb3bee
3 changed files with 151 additions and 141 deletions

255
main.go
View File

@@ -4,7 +4,6 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"math"
"net/http" "net/http"
"net/netip" "net/netip"
"os" "os"
@@ -59,6 +58,16 @@ type ExitNode struct {
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
} }
// ExitNodePingResult is one entry of the "pingResults" list that is sent with
// the "newt/wg/register" message. It pairs the latency measured for a single
// exit node with attributes copied from that node, so the receiving side can
// choose which exit node to use.
type ExitNodePingResult struct {
// ExitNodeID is the ID of the exit node that was pinged.
ExitNodeID int `json:"exitNodeId"`
// LatencyMs is the latency averaged over the successful ping attempts, in
// milliseconds; it is 0 when every attempt failed.
LatencyMs int64 `json:"latencyMs"`
// Weight is copied unchanged from the exit node.
Weight float64 `json:"weight"`
// Error is the text of the last ping error when no attempt succeeded; it is
// empty (and omitted from the JSON) otherwise.
Error string `json:"error,omitempty"`
// Name is copied unchanged from the exit node.
Name string `json:"exitNodeName"`
// Endpoint is copied unchanged from the exit node.
Endpoint string `json:"endpoint"`
// WasPreviouslyConnected is copied unchanged from the exit node.
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
var ( var (
endpoint string endpoint string
id string id string
@@ -76,7 +85,10 @@ var (
updownScript string updownScript string
tlsPrivateKey string tlsPrivateKey string
dockerSocket string dockerSocket string
pingInterval = 1 * time.Second
publicKey wgtypes.Key publicKey wgtypes.Key
pingStopChan chan struct{}
stopFunc func()
) )
func main() { func main() {
@@ -94,6 +106,7 @@ func main() {
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true" acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT") tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET") dockerSocket = os.Getenv("DOCKER_SOCKET")
pingIntervalStr := os.Getenv("PING_INTERVAL")
if endpoint == "" { if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server") flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -130,6 +143,17 @@ func main() {
if dockerSocket == "" { if dockerSocket == "" {
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)") flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
} }
if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
}
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
if err != nil {
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 1 second\n", pingIntervalStr)
pingInterval = 1 * time.Second
}
}
// do a --version check // do a --version check
version := flag.Bool("version", false, "Print the version") version := flag.Bool("version", false, "Print the version")
@@ -216,38 +240,41 @@ func main() {
} }
} }
pingStopChan := make(chan struct{})
defer close(pingStopChan)
// Register handlers for different message types // Register handlers for different message types
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) { client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
logger.Info("Received registration message") logger.Info("Received registration message")
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
}
if connected { if connected {
if pingStopChan != nil {
// Stop the ping check
close(pingStopChan)
pingStopChan = nil
}
// Stop proxy manager if running // Stop proxy manager if running
if pm != nil { if pm != nil {
pm.Stop() pm.Stop()
pm = nil pm = nil
} }
// Close WireGuard device if running // Close WireGuard device first - this will automatically close the TUN device
if dev != nil { if dev != nil {
dev.Close() dev.Close()
dev = nil dev = nil
} }
// Close TUN/netstack if running // Clear references but don't manually close since dev.Close() already did it
if tnet != nil { if tnet != nil {
tnet = nil tnet = nil
} }
if tun != nil { if tun != nil {
tun.Close() tun = nil // Don't call tun.Close() here since dev.Close() already closed it
tun = nil
} }
// Stop the ping check
close(pingStopChan)
// Mark as disconnected // Mark as disconnected
connected = false connected = false
} }
@@ -315,7 +342,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// as the pings will continue in the background // as the pings will continue in the background
if !connected { if !connected {
logger.Info("Starting ping check") logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, client, pingStopChan) pingStopChan = startPingCheck(tnet, wgData.ServerIP, client)
} }
// Create proxy manager // Create proxy manager
@@ -348,30 +375,32 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) { client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received disconnect message") logger.Info("Received disconnect message")
if pingStopChan != nil {
// Stop the ping check
close(pingStopChan)
pingStopChan = nil
}
// Stop proxy manager if running // Stop proxy manager if running
if pm != nil { if pm != nil {
pm.Stop() pm.Stop()
pm = nil pm = nil
} }
// Close WireGuard device if running // Close WireGuard device first - this will automatically close the TUN device
if dev != nil { if dev != nil {
dev.Close() dev.Close()
dev = nil dev = nil
} }
// Close TUN/netstack if running // Clear references but don't manually close since dev.Close() already did it
if tnet != nil { if tnet != nil {
tnet = nil tnet = nil
} }
if tun != nil { if tun != nil {
tun.Close() tun = nil // Don't call tun.Close() here since dev.Close() already closed it
tun = nil
} }
// Stop the ping check
close(pingStopChan)
// Mark as disconnected // Mark as disconnected
connected = false connected = false
@@ -380,9 +409,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Info("Received ping message") logger.Info("Received ping message")
if stopFunc != nil {
stopFunc() // stop the ws from sending more requests
stopFunc = nil // reset stopFunc to nil to avoid double stopping
}
// Parse the incoming list of exit nodes // Parse the incoming list of exit nodes
// Exit nodes is a json
var exitNodeData ExitNodeData var exitNodeData ExitNodeData
jsonData, err := json.Marshal(msg.Data) jsonData, err := json.Marshal(msg.Data)
@@ -408,8 +440,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
} }
results := make([]nodeResult, len(exitNodes)) results := make([]nodeResult, len(exitNodes))
const pingAttempts = 3
for i, node := range exitNodes { for i, node := range exitNodes {
start := time.Now() var totalLatency time.Duration
var lastErr error
successes := 0
client := &http.Client{ client := &http.Client{
Timeout: 5 * time.Second, Timeout: 5 * time.Second,
} }
@@ -420,89 +455,54 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
if !strings.HasSuffix(url, "/ping") { if !strings.HasSuffix(url, "/ping") {
url = strings.TrimRight(url, "/") + "/ping" url = strings.TrimRight(url, "/") + "/ping"
} }
resp, err := client.Get(url) for j := 0; j < pingAttempts; j++ {
latency := time.Since(start) start := time.Now()
if err != nil { resp, err := client.Get(url)
logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err) latency := time.Since(start)
results[i] = nodeResult{Node: node, Latency: latency, Err: err} if err != nil {
continue lastErr = err
} logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err)
resp.Body.Close() continue
results[i] = nodeResult{Node: node, Latency: latency, Err: nil}
// logger.Info("Exit node %s latency: %v", node.Name, latency)
}
// we will need to tweak these
const (
latencyPenaltyExponent = 1.5 // make latency matter more
lastNodeScoreBoost = 1.10 // 10% preference for the last used node
scoreTolerancePercent = 5.0 // allow last node if within 5% of best score
)
var bestNode *ExitNode
var bestScore float64 = -1e12
var bestLatency time.Duration = 1e12
type ExitNodeScore struct {
Node ExitNode
Score float64
Latency time.Duration
}
var candidateNodes []ExitNodeScore
for _, res := range results {
if res.Err != nil || res.Node.Weight <= 0 {
continue
}
latencyMs := float64(res.Latency.Milliseconds())
score := res.Node.Weight / math.Pow(latencyMs, latencyPenaltyExponent)
// slight boost if this is the last used node
if res.Node.WasPreviouslyConnected == true {
score *= lastNodeScoreBoost
}
logger.Info("Exit node %s with score: %.2f (latency: %dms, weight: %.2f)", res.Node.Name, score, res.Latency.Milliseconds(), res.Node.Weight)
candidateNodes = append(candidateNodes, ExitNodeScore{Node: res.Node, Score: score, Latency: res.Latency})
if score > bestScore {
bestScore = score
bestLatency = res.Latency
bestNode = &res.Node
} else if score == bestScore && res.Latency < bestLatency {
bestLatency = res.Latency
bestNode = &res.Node
}
}
// check if last used node is close enough in score
for _, cand := range candidateNodes {
if cand.Node.WasPreviouslyConnected {
if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint)
bestNode = &cand.Node
} }
break resp.Body.Close()
totalLatency += latency
successes++
}
var avgLatency time.Duration
if successes > 0 {
avgLatency = totalLatency / time.Duration(successes)
}
if successes == 0 {
results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr}
} else {
results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil}
} }
} }
if bestNode == nil { // Prepare data to send to the cloud for selection
logger.Warn("No suitable exit node found") var pingResults []ExitNodePingResult
return for _, res := range results {
errMsg := ""
if res.Err != nil {
errMsg = res.Err.Error()
}
pingResults = append(pingResults, ExitNodePingResult{
ExitNodeID: res.Node.ID,
LatencyMs: res.Latency.Milliseconds(),
Weight: res.Node.Weight,
Error: errMsg,
Name: res.Node.Name,
Endpoint: res.Node.Endpoint,
WasPreviouslyConnected: res.Node.WasPreviouslyConnected,
})
} }
logger.Info("Selected exit node: %s (%s)", bestNode.Name, bestNode.Endpoint) // Send the ping results to the cloud for selection
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
err = client.SendMessage("newt/wg/register", map[string]interface{}{ "publicKey": publicKey.String(),
"publicKey": publicKey.String(), "pingResults": pingResults,
"exitNodeId": bestNode.ID, }, 1*time.Second)
}) logger.Info("Sent exit node ping results to cloud for selection")
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return
}
}) })
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
@@ -648,10 +648,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Debug("Public key: %s", publicKey) logger.Debug("Public key: %s", publicKey)
// request from the server the list of nodes to ping at newt/ping/request // request from the server the list of nodes to ping at newt/ping/request
err := client.SendMessage("newt/ping/request", map[string]interface{}{}) stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
if wgService != nil { if wgService != nil {
wgService.LoadRemoteConfig() wgService.LoadRemoteConfig()
@@ -699,75 +696,59 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
os.Exit(0) os.Exit(0)
} }
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) { func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
initialInterval := 10 * time.Second initialInterval := pingInterval
maxInterval := 60 * time.Second maxInterval := 3 * time.Second
currentInterval := initialInterval currentInterval := initialInterval
consecutiveFailures := 0 consecutiveFailures := 0
connectionLost := false connectionLost := false
ticker := time.NewTicker(currentInterval)
defer ticker.Stop() pingStopChan := make(chan struct{})
go func() { go func() {
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
_, err := ping(tnet, serverIP) _, err := ping(tnet, serverIP)
if err != nil { if err != nil {
consecutiveFailures++ consecutiveFailures++
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
// Check if this is the first failure (connection just lost) if consecutiveFailures >= 3 && currentInterval < maxInterval {
if !connectionLost { if !connectionLost {
connectionLost = true connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.") stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.") }
}
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
consecutiveFailures, err)
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
// Increase interval if we have consistent failures, with a maximum cap
if consecutiveFailures >= 5 && currentInterval < maxInterval {
// Increase by 50% each time, up to the maximum
currentInterval = time.Duration(float64(currentInterval) * 1.5) currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval { if currentInterval > maxInterval {
currentInterval = maxInterval currentInterval = maxInterval
} }
ticker.Reset(currentInterval) ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures", logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
currentInterval)
// Restart the connection flow
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
} }
} else { } else {
// Check if connection was previously lost and is now restored
if connectionLost { if connectionLost {
connectionLost = false connectionLost = false
logger.Info("Connection to server restored!") logger.Info("Connection to server restored!")
} }
// On success, if we've backed off, gradually return to normal interval
if currentInterval > initialInterval { if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8) currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval { if currentInterval < initialInterval {
currentInterval = initialInterval currentInterval = initialInterval
} }
ticker.Reset(currentInterval) ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping", logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
currentInterval)
} }
consecutiveFailures = 0 consecutiveFailures = 0
} }
case <-stopChan: case <-pingStopChan:
logger.Info("Stopping ping check") logger.Info("Stopping ping check")
return return
} }
} }
}() }()
return pingStopChan
} }

View File

@@ -44,7 +44,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
requestPing := icmp.Echo{ requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16), Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"), Data: []byte("f"),
} }
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
@@ -52,7 +52,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
} }
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
return 0, fmt.Errorf("failed to set read deadline: %w", err) return 0, fmt.Errorf("failed to set read deadline: %w", err)
} }
@@ -84,12 +84,14 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
latency := time.Since(start) latency := time.Since(start)
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
return latency, nil return latency, nil
} }
func pingWithRetry(tnet *netstack.Net, dst string) error { func pingWithRetry(tnet *netstack.Net, dst string) error {
const ( const (
initialMaxAttempts = 15 initialMaxAttempts = 5
initialRetryDelay = 2 * time.Second initialRetryDelay = 2 * time.Second
maxRetryDelay = 60 * time.Second // Cap the maximum delay maxRetryDelay = 60 * time.Second // Cap the maximum delay
) )

View File

@@ -9,11 +9,12 @@ import (
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"software.sslmate.com/src/go-pkcs12"
"strings" "strings"
"sync" "sync"
"time" "time"
"software.sslmate.com/src/go-pkcs12"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
@@ -126,6 +127,32 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
return c.conn.WriteJSON(msg) return c.conn.WriteJSON(msg)
} }
// SendMessageInterval sends a message of the given type right away and then
// re-sends it every interval until the returned stop function is called.
//
// All sends happen on a background goroutine. A failed send is logged with
// logger.Error and does not end the loop. interval must be positive, because
// time.NewTicker panics otherwise.
//
// The returned stop function may be called more than once and from several
// goroutines: only the first call closes the internal channel, so a repeated
// call cannot panic with "close of closed channel".
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
	stopChan := make(chan struct{})

	go func() {
		// Send immediately so the first message does not wait a full interval.
		if err := c.SendMessage(messageType, data); err != nil {
			logger.Error("Failed to send initial message: %v", err)
		}

		ticker := time.NewTicker(interval)
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				if err := c.SendMessage(messageType, data); err != nil {
					logger.Error("Failed to send message: %v", err)
				}
			case <-stopChan:
				return
			}
		}
	}()

	// Guard the close so stop is idempotent; callers keep the closure in shared
	// state and may invoke it from more than one place.
	var once sync.Once
	return func() {
		once.Do(func() { close(stopChan) })
	}
}
// RegisterHandler registers a handler for a specific message type // RegisterHandler registers a handler for a specific message type
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) { func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
c.handlersMux.Lock() c.handlersMux.Lock()