mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
New tunnel reconnect works
This commit is contained in:
255
main.go
255
main.go
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
@@ -59,6 +58,16 @@ type ExitNode struct {
|
||||
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
|
||||
}
|
||||
|
||||
type ExitNodePingResult struct {
|
||||
ExitNodeID int `json:"exitNodeId"`
|
||||
LatencyMs int64 `json:"latencyMs"`
|
||||
Weight float64 `json:"weight"`
|
||||
Error string `json:"error,omitempty"`
|
||||
Name string `json:"exitNodeName"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
|
||||
}
|
||||
|
||||
var (
|
||||
endpoint string
|
||||
id string
|
||||
@@ -76,7 +85,10 @@ var (
|
||||
updownScript string
|
||||
tlsPrivateKey string
|
||||
dockerSocket string
|
||||
pingInterval = 1 * time.Second
|
||||
publicKey wgtypes.Key
|
||||
pingStopChan chan struct{}
|
||||
stopFunc func()
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -94,6 +106,7 @@ func main() {
|
||||
acceptClients = os.Getenv("ACCEPT_CLIENTS") == "true"
|
||||
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||
|
||||
if endpoint == "" {
|
||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||
@@ -130,6 +143,17 @@ func main() {
|
||||
if dockerSocket == "" {
|
||||
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
|
||||
}
|
||||
if pingIntervalStr == "" {
|
||||
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
|
||||
}
|
||||
|
||||
if pingIntervalStr != "" {
|
||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
||||
if err != nil {
|
||||
fmt.Printf("Invalid PING_INTERVAL value: %s, using default 1 second\n", pingIntervalStr)
|
||||
pingInterval = 1 * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
// do a --version check
|
||||
version := flag.Bool("version", false, "Print the version")
|
||||
@@ -216,38 +240,41 @@ func main() {
|
||||
}
|
||||
}
|
||||
|
||||
pingStopChan := make(chan struct{})
|
||||
defer close(pingStopChan)
|
||||
|
||||
// Register handlers for different message types
|
||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received registration message")
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
}
|
||||
|
||||
if connected {
|
||||
if pingStopChan != nil {
|
||||
// Stop the ping check
|
||||
close(pingStopChan)
|
||||
pingStopChan = nil
|
||||
}
|
||||
|
||||
// Stop proxy manager if running
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
pm = nil
|
||||
}
|
||||
|
||||
// Close WireGuard device if running
|
||||
// Close WireGuard device first - this will automatically close the TUN device
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
dev = nil
|
||||
}
|
||||
|
||||
// Close TUN/netstack if running
|
||||
// Clear references but don't manually close since dev.Close() already did it
|
||||
if tnet != nil {
|
||||
tnet = nil
|
||||
}
|
||||
if tun != nil {
|
||||
tun.Close()
|
||||
tun = nil
|
||||
tun = nil // Don't call tun.Close() here since dev.Close() already closed it
|
||||
}
|
||||
|
||||
// Stop the ping check
|
||||
close(pingStopChan)
|
||||
|
||||
// Mark as disconnected
|
||||
connected = false
|
||||
}
|
||||
@@ -315,7 +342,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
// as the pings will continue in the background
|
||||
if !connected {
|
||||
logger.Info("Starting ping check")
|
||||
startPingCheck(tnet, wgData.ServerIP, client, pingStopChan)
|
||||
pingStopChan = startPingCheck(tnet, wgData.ServerIP, client)
|
||||
}
|
||||
|
||||
// Create proxy manager
|
||||
@@ -348,30 +375,32 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received disconnect message")
|
||||
|
||||
if pingStopChan != nil {
|
||||
// Stop the ping check
|
||||
close(pingStopChan)
|
||||
pingStopChan = nil
|
||||
}
|
||||
|
||||
// Stop proxy manager if running
|
||||
if pm != nil {
|
||||
pm.Stop()
|
||||
pm = nil
|
||||
}
|
||||
|
||||
// Close WireGuard device if running
|
||||
// Close WireGuard device first - this will automatically close the TUN device
|
||||
if dev != nil {
|
||||
dev.Close()
|
||||
dev = nil
|
||||
}
|
||||
|
||||
// Close TUN/netstack if running
|
||||
// Clear references but don't manually close since dev.Close() already did it
|
||||
if tnet != nil {
|
||||
tnet = nil
|
||||
}
|
||||
if tun != nil {
|
||||
tun.Close()
|
||||
tun = nil
|
||||
tun = nil // Don't call tun.Close() here since dev.Close() already closed it
|
||||
}
|
||||
|
||||
// Stop the ping check
|
||||
close(pingStopChan)
|
||||
|
||||
// Mark as disconnected
|
||||
connected = false
|
||||
|
||||
@@ -380,9 +409,12 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
|
||||
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
|
||||
logger.Info("Received ping message")
|
||||
if stopFunc != nil {
|
||||
stopFunc() // stop the ws from sending more requests
|
||||
stopFunc = nil // reset stopFunc to nil to avoid double stopping
|
||||
}
|
||||
|
||||
// Parse the incoming list of exit nodes
|
||||
// Exit nodes is a json
|
||||
var exitNodeData ExitNodeData
|
||||
|
||||
jsonData, err := json.Marshal(msg.Data)
|
||||
@@ -408,8 +440,11 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
}
|
||||
|
||||
results := make([]nodeResult, len(exitNodes))
|
||||
const pingAttempts = 3
|
||||
for i, node := range exitNodes {
|
||||
start := time.Now()
|
||||
var totalLatency time.Duration
|
||||
var lastErr error
|
||||
successes := 0
|
||||
client := &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
@@ -420,89 +455,54 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
if !strings.HasSuffix(url, "/ping") {
|
||||
url = strings.TrimRight(url, "/") + "/ping"
|
||||
}
|
||||
resp, err := client.Get(url)
|
||||
latency := time.Since(start)
|
||||
if err != nil {
|
||||
logger.Warn("Failed to ping exit node %s (%s): %v", node.ID, url, err)
|
||||
results[i] = nodeResult{Node: node, Latency: latency, Err: err}
|
||||
continue
|
||||
}
|
||||
resp.Body.Close()
|
||||
results[i] = nodeResult{Node: node, Latency: latency, Err: nil}
|
||||
// logger.Info("Exit node %s latency: %v", node.Name, latency)
|
||||
}
|
||||
|
||||
// we will need to tweak these
|
||||
const (
|
||||
latencyPenaltyExponent = 1.5 // make latency matter more
|
||||
lastNodeScoreBoost = 1.10 // 10% preference for the last used node
|
||||
scoreTolerancePercent = 5.0 // allow last node if within 5% of best score
|
||||
)
|
||||
|
||||
var bestNode *ExitNode
|
||||
var bestScore float64 = -1e12
|
||||
var bestLatency time.Duration = 1e12
|
||||
|
||||
type ExitNodeScore struct {
|
||||
Node ExitNode
|
||||
Score float64
|
||||
Latency time.Duration
|
||||
}
|
||||
var candidateNodes []ExitNodeScore
|
||||
|
||||
for _, res := range results {
|
||||
if res.Err != nil || res.Node.Weight <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
latencyMs := float64(res.Latency.Milliseconds())
|
||||
score := res.Node.Weight / math.Pow(latencyMs, latencyPenaltyExponent)
|
||||
|
||||
// slight boost if this is the last used node
|
||||
if res.Node.WasPreviouslyConnected == true {
|
||||
score *= lastNodeScoreBoost
|
||||
}
|
||||
|
||||
logger.Info("Exit node %s with score: %.2f (latency: %dms, weight: %.2f)", res.Node.Name, score, res.Latency.Milliseconds(), res.Node.Weight)
|
||||
|
||||
candidateNodes = append(candidateNodes, ExitNodeScore{Node: res.Node, Score: score, Latency: res.Latency})
|
||||
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
bestLatency = res.Latency
|
||||
bestNode = &res.Node
|
||||
} else if score == bestScore && res.Latency < bestLatency {
|
||||
bestLatency = res.Latency
|
||||
bestNode = &res.Node
|
||||
}
|
||||
}
|
||||
|
||||
// check if last used node is close enough in score
|
||||
for _, cand := range candidateNodes {
|
||||
if cand.Node.WasPreviouslyConnected {
|
||||
if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
|
||||
logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint)
|
||||
bestNode = &cand.Node
|
||||
for j := 0; j < pingAttempts; j++ {
|
||||
start := time.Now()
|
||||
resp, err := client.Get(url)
|
||||
latency := time.Since(start)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
logger.Warn("Failed to ping exit node %d (%s) attempt %d: %v", node.ID, url, j+1, err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
resp.Body.Close()
|
||||
totalLatency += latency
|
||||
successes++
|
||||
}
|
||||
var avgLatency time.Duration
|
||||
if successes > 0 {
|
||||
avgLatency = totalLatency / time.Duration(successes)
|
||||
}
|
||||
if successes == 0 {
|
||||
results[i] = nodeResult{Node: node, Latency: 0, Err: lastErr}
|
||||
} else {
|
||||
results[i] = nodeResult{Node: node, Latency: avgLatency, Err: nil}
|
||||
}
|
||||
}
|
||||
|
||||
if bestNode == nil {
|
||||
logger.Warn("No suitable exit node found")
|
||||
return
|
||||
// Prepare data to send to the cloud for selection
|
||||
var pingResults []ExitNodePingResult
|
||||
for _, res := range results {
|
||||
errMsg := ""
|
||||
if res.Err != nil {
|
||||
errMsg = res.Err.Error()
|
||||
}
|
||||
pingResults = append(pingResults, ExitNodePingResult{
|
||||
ExitNodeID: res.Node.ID,
|
||||
LatencyMs: res.Latency.Milliseconds(),
|
||||
Weight: res.Node.Weight,
|
||||
Error: errMsg,
|
||||
Name: res.Node.Name,
|
||||
Endpoint: res.Node.Endpoint,
|
||||
WasPreviouslyConnected: res.Node.WasPreviouslyConnected,
|
||||
})
|
||||
}
|
||||
|
||||
logger.Info("Selected exit node: %s (%s)", bestNode.Name, bestNode.Endpoint)
|
||||
|
||||
err = client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"exitNodeId": bestNode.ID,
|
||||
})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send registration message: %v", err)
|
||||
return
|
||||
}
|
||||
// Send the ping results to the cloud for selection
|
||||
stopFunc = client.SendMessageInterval("newt/wg/register", map[string]interface{}{
|
||||
"publicKey": publicKey.String(),
|
||||
"pingResults": pingResults,
|
||||
}, 1*time.Second)
|
||||
logger.Info("Sent exit node ping results to cloud for selection")
|
||||
})
|
||||
|
||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
||||
@@ -648,10 +648,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
logger.Debug("Public key: %s", publicKey)
|
||||
|
||||
// request from the server the list of nodes to ping at newt/ping/request
|
||||
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping request: %v", err)
|
||||
}
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
|
||||
if wgService != nil {
|
||||
wgService.LoadRemoteConfig()
|
||||
@@ -699,75 +696,59 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) {
|
||||
initialInterval := 10 * time.Second
|
||||
maxInterval := 60 * time.Second
|
||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
|
||||
initialInterval := pingInterval
|
||||
maxInterval := 3 * time.Second
|
||||
currentInterval := initialInterval
|
||||
consecutiveFailures := 0
|
||||
connectionLost := false
|
||||
ticker := time.NewTicker(currentInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
pingStopChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(currentInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
_, err := ping(tnet, serverIP)
|
||||
if err != nil {
|
||||
consecutiveFailures++
|
||||
|
||||
// Check if this is the first failure (connection just lost)
|
||||
if !connectionLost {
|
||||
connectionLost = true
|
||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
||||
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
||||
}
|
||||
|
||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
|
||||
consecutiveFailures, err)
|
||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||
|
||||
// Increase interval if we have consistent failures, with a maximum cap
|
||||
if consecutiveFailures >= 5 && currentInterval < maxInterval {
|
||||
// Increase by 50% each time, up to the maximum
|
||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
||||
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||
if !connectionLost {
|
||||
connectionLost = true
|
||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||
}
|
||||
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||
if currentInterval > maxInterval {
|
||||
currentInterval = maxInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Debug("Increased ping check interval to %v due to consecutive failures",
|
||||
currentInterval)
|
||||
|
||||
// Restart the connection flow
|
||||
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
|
||||
if err != nil {
|
||||
logger.Error("Failed to send ping request: %v", err)
|
||||
}
|
||||
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
||||
}
|
||||
} else {
|
||||
// Check if connection was previously lost and is now restored
|
||||
if connectionLost {
|
||||
connectionLost = false
|
||||
logger.Info("Connection to server restored!")
|
||||
}
|
||||
|
||||
// On success, if we've backed off, gradually return to normal interval
|
||||
if currentInterval > initialInterval {
|
||||
currentInterval = time.Duration(float64(currentInterval) * 0.8)
|
||||
if currentInterval < initialInterval {
|
||||
currentInterval = initialInterval
|
||||
}
|
||||
ticker.Reset(currentInterval)
|
||||
logger.Info("Decreased ping check interval to %v after successful ping",
|
||||
currentInterval)
|
||||
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
|
||||
}
|
||||
consecutiveFailures = 0
|
||||
}
|
||||
case <-stopChan:
|
||||
case <-pingStopChan:
|
||||
logger.Info("Stopping ping check")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return pingStopChan
|
||||
}
|
||||
|
||||
8
util.go
8
util.go
@@ -44,7 +44,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
||||
|
||||
requestPing := icmp.Echo{
|
||||
Seq: rand.Intn(1 << 16),
|
||||
Data: []byte("gopher burrow"),
|
||||
Data: []byte("f"),
|
||||
}
|
||||
|
||||
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
|
||||
@@ -52,7 +52,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
||||
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
|
||||
}
|
||||
|
||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
|
||||
return 0, fmt.Errorf("failed to set read deadline: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,12 +84,14 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
||||
|
||||
latency := time.Since(start)
|
||||
|
||||
logger.Debug("Ping to %s successful, latency: %v", dst, latency)
|
||||
|
||||
return latency, nil
|
||||
}
|
||||
|
||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
||||
const (
|
||||
initialMaxAttempts = 15
|
||||
initialMaxAttempts = 5
|
||||
initialRetryDelay = 2 * time.Second
|
||||
maxRetryDelay = 60 * time.Second // Cap the maximum delay
|
||||
)
|
||||
|
||||
@@ -9,11 +9,12 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"software.sslmate.com/src/go-pkcs12"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"software.sslmate.com/src/go-pkcs12"
|
||||
|
||||
"github.com/fosrl/newt/logger"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
@@ -126,6 +127,32 @@ func (c *Client) SendMessage(messageType string, data interface{}) error {
|
||||
return c.conn.WriteJSON(msg)
|
||||
}
|
||||
|
||||
func (c *Client) SendMessageInterval(messageType string, data interface{}, interval time.Duration) (stop func()) {
|
||||
stopChan := make(chan struct{})
|
||||
go func() {
|
||||
err := c.SendMessage(messageType, data) // Send immediately
|
||||
if err != nil {
|
||||
logger.Error("Failed to send initial message: %v", err)
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
err = c.SendMessage(messageType, data)
|
||||
if err != nil {
|
||||
logger.Error("Failed to send message: %v", err)
|
||||
}
|
||||
case <-stopChan:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return func() {
|
||||
close(stopChan)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterHandler registers a handler for a specific message type
|
||||
func (c *Client) RegisterHandler(messageType string, handler MessageHandler) {
|
||||
c.handlersMux.Lock()
|
||||
|
||||
Reference in New Issue
Block a user