mirror of
https://github.com/fosrl/newt.git
synced 2026-03-10 20:56:40 +00:00
Reorg and add timeout
This commit is contained in:
72
main.go
72
main.go
@@ -86,6 +86,7 @@ var (
|
|||||||
tlsPrivateKey string
|
tlsPrivateKey string
|
||||||
dockerSocket string
|
dockerSocket string
|
||||||
pingInterval = 1 * time.Second
|
pingInterval = 1 * time.Second
|
||||||
|
pingTimeout = 2 * time.Second
|
||||||
publicKey wgtypes.Key
|
publicKey wgtypes.Key
|
||||||
pingStopChan chan struct{}
|
pingStopChan chan struct{}
|
||||||
stopFunc func()
|
stopFunc func()
|
||||||
@@ -107,6 +108,7 @@ func main() {
|
|||||||
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
||||||
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
pingIntervalStr := os.Getenv("PING_INTERVAL")
|
||||||
|
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
@@ -146,6 +148,9 @@ func main() {
|
|||||||
if pingIntervalStr == "" {
|
if pingIntervalStr == "" {
|
||||||
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
|
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
|
||||||
}
|
}
|
||||||
|
if pingTimeoutStr == "" {
|
||||||
|
flag.StringVar(&pingTimeoutStr, "ping-timeout", "2s", " Timeout for each ping (default 2s)")
|
||||||
|
}
|
||||||
|
|
||||||
if pingIntervalStr != "" {
|
if pingIntervalStr != "" {
|
||||||
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
pingInterval, err = time.ParseDuration(pingIntervalStr)
|
||||||
@@ -155,6 +160,14 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pingTimeoutStr != "" {
|
||||||
|
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 2 seconds\n", pingTimeoutStr)
|
||||||
|
pingTimeout = 2 * time.Second
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
// do a --version check
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
|
||||||
@@ -336,7 +349,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||||
|
|
||||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
// Even if pingWithRetry returns an error, it will continue trying in the background
|
||||||
_ = pingWithRetry(tnet, wgData.ServerIP)
|
_ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout)
|
||||||
|
|
||||||
// Always mark as connected and start the proxy manager regardless of initial ping result
|
// Always mark as connected and start the proxy manager regardless of initial ping result
|
||||||
// as the pings will continue in the background
|
// as the pings will continue in the background
|
||||||
@@ -695,60 +708,3 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
logger.Info("Exiting...")
|
logger.Info("Exiting...")
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
|
|
||||||
initialInterval := pingInterval
|
|
||||||
maxInterval := 3 * time.Second
|
|
||||||
currentInterval := initialInterval
|
|
||||||
consecutiveFailures := 0
|
|
||||||
connectionLost := false
|
|
||||||
|
|
||||||
pingStopChan := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(currentInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ticker.C:
|
|
||||||
_, err := ping(tnet, serverIP)
|
|
||||||
if err != nil {
|
|
||||||
consecutiveFailures++
|
|
||||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
|
||||||
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
|
||||||
if !connectionLost {
|
|
||||||
connectionLost = true
|
|
||||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
|
||||||
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
|
||||||
}
|
|
||||||
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
|
||||||
if currentInterval > maxInterval {
|
|
||||||
currentInterval = maxInterval
|
|
||||||
}
|
|
||||||
ticker.Reset(currentInterval)
|
|
||||||
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if connectionLost {
|
|
||||||
connectionLost = false
|
|
||||||
logger.Info("Connection to server restored!")
|
|
||||||
}
|
|
||||||
if currentInterval > initialInterval {
|
|
||||||
currentInterval = time.Duration(float64(currentInterval) * 0.8)
|
|
||||||
if currentInterval < initialInterval {
|
|
||||||
currentInterval = initialInterval
|
|
||||||
}
|
|
||||||
ticker.Reset(currentInterval)
|
|
||||||
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
|
|
||||||
}
|
|
||||||
consecutiveFailures = 0
|
|
||||||
}
|
|
||||||
case <-pingStopChan:
|
|
||||||
logger.Info("Stopping ping check")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return pingStopChan
|
|
||||||
}
|
|
||||||
|
|||||||
68
util.go
68
util.go
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"github.com/fosrl/newt/logger"
|
"github.com/fosrl/newt/logger"
|
||||||
"github.com/fosrl/newt/proxy"
|
"github.com/fosrl/newt/proxy"
|
||||||
|
"github.com/fosrl/newt/websocket"
|
||||||
"golang.org/x/exp/rand"
|
"golang.org/x/exp/rand"
|
||||||
"golang.org/x/net/icmp"
|
"golang.org/x/net/icmp"
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
@@ -34,7 +35,7 @@ func fixKey(key string) string {
|
|||||||
return hex.EncodeToString(decoded)
|
return hex.EncodeToString(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
|
||||||
logger.Debug("Pinging %s", dst)
|
logger.Debug("Pinging %s", dst)
|
||||||
socket, err := tnet.Dial("ping4", dst)
|
socket, err := tnet.Dial("ping4", dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -52,7 +53,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
|||||||
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
|
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
|
if err := socket.SetReadDeadline(time.Now().Add(timeout)); err != nil {
|
||||||
return 0, fmt.Errorf("failed to set read deadline: %w", err)
|
return 0, fmt.Errorf("failed to set read deadline: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,7 +90,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
|
|||||||
return latency, nil
|
return latency, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func pingWithRetry(tnet *netstack.Net, dst string) error {
|
func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) error {
|
||||||
const (
|
const (
|
||||||
initialMaxAttempts = 5
|
initialMaxAttempts = 5
|
||||||
initialRetryDelay = 2 * time.Second
|
initialRetryDelay = 2 * time.Second
|
||||||
@@ -101,7 +102,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
|
|
||||||
// First try with the initial parameters
|
// First try with the initial parameters
|
||||||
logger.Info("Ping attempt %d", attempt)
|
logger.Info("Ping attempt %d", attempt)
|
||||||
if latency, err := ping(tnet, dst); err == nil {
|
if latency, err := ping(tnet, dst, timeout); err == nil {
|
||||||
// Successful ping
|
// Successful ping
|
||||||
logger.Info("Ping latency: %v", latency)
|
logger.Info("Ping latency: %v", latency)
|
||||||
|
|
||||||
@@ -118,7 +119,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
for {
|
for {
|
||||||
logger.Info("Ping attempt %d", attempt)
|
logger.Info("Ping attempt %d", attempt)
|
||||||
|
|
||||||
if latency, err := ping(tnet, dst); err != nil {
|
if latency, err := ping(tnet, dst, timeout); err != nil {
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
|
|
||||||
// Increase delay after certain thresholds but cap it
|
// Increase delay after certain thresholds but cap it
|
||||||
@@ -146,6 +147,63 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
|
||||||
|
initialInterval := pingInterval
|
||||||
|
maxInterval := 3 * time.Second
|
||||||
|
currentInterval := initialInterval
|
||||||
|
consecutiveFailures := 0
|
||||||
|
connectionLost := false
|
||||||
|
|
||||||
|
pingStopChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(currentInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ticker.C:
|
||||||
|
_, err := ping(tnet, serverIP, pingTimeout)
|
||||||
|
if err != nil {
|
||||||
|
consecutiveFailures++
|
||||||
|
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
||||||
|
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||||
|
if !connectionLost {
|
||||||
|
connectionLost = true
|
||||||
|
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||||
|
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
|
||||||
|
}
|
||||||
|
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||||
|
if currentInterval > maxInterval {
|
||||||
|
currentInterval = maxInterval
|
||||||
|
}
|
||||||
|
ticker.Reset(currentInterval)
|
||||||
|
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if connectionLost {
|
||||||
|
connectionLost = false
|
||||||
|
logger.Info("Connection to server restored!")
|
||||||
|
}
|
||||||
|
if currentInterval > initialInterval {
|
||||||
|
currentInterval = time.Duration(float64(currentInterval) * 0.8)
|
||||||
|
if currentInterval < initialInterval {
|
||||||
|
currentInterval = initialInterval
|
||||||
|
}
|
||||||
|
ticker.Reset(currentInterval)
|
||||||
|
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
|
||||||
|
}
|
||||||
|
consecutiveFailures = 0
|
||||||
|
}
|
||||||
|
case <-pingStopChan:
|
||||||
|
logger.Info("Stopping ping check")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return pingStopChan
|
||||||
|
}
|
||||||
|
|
||||||
func parseLogLevel(level string) logger.LogLevel {
|
func parseLogLevel(level string) logger.LogLevel {
|
||||||
switch strings.ToUpper(level) {
|
switch strings.ToUpper(level) {
|
||||||
case "DEBUG":
|
case "DEBUG":
|
||||||
|
|||||||
Reference in New Issue
Block a user