Reorg and add timeout

This commit is contained in:
Owen
2025-06-19 15:59:21 -04:00
parent 1c75eb3bee
commit bb1318278a
2 changed files with 77 additions and 63 deletions

72
main.go
View File

@@ -86,6 +86,7 @@ var (
tlsPrivateKey string
dockerSocket string
pingInterval = 1 * time.Second
pingTimeout = 2 * time.Second
publicKey wgtypes.Key
pingStopChan chan struct{}
stopFunc func()
@@ -107,6 +108,7 @@ func main() {
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
dockerSocket = os.Getenv("DOCKER_SOCKET")
pingIntervalStr := os.Getenv("PING_INTERVAL")
pingTimeoutStr := os.Getenv("PING_TIMEOUT")
if endpoint == "" {
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
@@ -146,6 +148,9 @@ func main() {
if pingIntervalStr == "" {
flag.StringVar(&pingIntervalStr, "ping-interval", "1s", "Interval for pinging the server (default 1s)")
}
if pingTimeoutStr == "" {
flag.StringVar(&pingTimeoutStr, "ping-timeout", "2s", " Timeout for each ping (default 2s)")
}
if pingIntervalStr != "" {
pingInterval, err = time.ParseDuration(pingIntervalStr)
@@ -155,6 +160,14 @@ func main() {
}
}
if pingTimeoutStr != "" {
pingTimeout, err = time.ParseDuration(pingTimeoutStr)
if err != nil {
fmt.Printf("Invalid PING_TIMEOUT value: %s, using default 2 seconds\n", pingTimeoutStr)
pingTimeout = 2 * time.Second
}
}
// do a --version check
version := flag.Bool("version", false, "Print the version")
@@ -336,7 +349,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Info("WireGuard device created. Lets ping the server now...")
// Even if pingWithRetry returns an error, it will continue trying in the background
_ = pingWithRetry(tnet, wgData.ServerIP)
_ = pingWithRetry(tnet, wgData.ServerIP, pingTimeout)
// Always mark as connected and start the proxy manager regardless of initial ping result
// as the pings will continue in the background
@@ -695,60 +708,3 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
logger.Info("Exiting...")
os.Exit(0)
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
initialInterval := pingInterval
maxInterval := 3 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
connectionLost := false
pingStopChan := make(chan struct{})
go func() {
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP)
if err != nil {
consecutiveFailures++
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
if consecutiveFailures >= 3 && currentInterval < maxInterval {
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
}
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
}
} else {
if connectionLost {
connectionLost = false
logger.Info("Connection to server restored!")
}
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
}
consecutiveFailures = 0
}
case <-pingStopChan:
logger.Info("Stopping ping check")
return
}
}
}()
return pingStopChan
}

68
util.go
View File

@@ -13,6 +13,7 @@ import (
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket"
"golang.org/x/exp/rand"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
@@ -34,7 +35,7 @@ func fixKey(key string) string {
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
func ping(tnet *netstack.Net, dst string, timeout time.Duration) (time.Duration, error) {
logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
@@ -52,7 +53,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 2)); err != nil {
if err := socket.SetReadDeadline(time.Now().Add(timeout)); err != nil {
return 0, fmt.Errorf("failed to set read deadline: %w", err)
}
@@ -89,7 +90,7 @@ func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
return latency, nil
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
func pingWithRetry(tnet *netstack.Net, dst string, timeout time.Duration) error {
const (
initialMaxAttempts = 5
initialRetryDelay = 2 * time.Second
@@ -101,7 +102,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
// First try with the initial parameters
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err == nil {
if latency, err := ping(tnet, dst, timeout); err == nil {
// Successful ping
logger.Info("Ping latency: %v", latency)
@@ -118,7 +119,7 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
for {
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err != nil {
if latency, err := ping(tnet, dst, timeout); err != nil {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
// Increase delay after certain thresholds but cap it
@@ -146,6 +147,63 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
return fmt.Errorf("initial ping attempts failed, continuing in background")
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client) chan struct{} {
initialInterval := pingInterval
maxInterval := 3 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
connectionLost := false
pingStopChan := make(chan struct{})
go func() {
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP, pingTimeout)
if err != nil {
consecutiveFailures++
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
if consecutiveFailures >= 3 && currentInterval < maxInterval {
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
stopFunc = client.SendMessageInterval("newt/ping/request", map[string]interface{}{}, 3*time.Second)
}
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures", currentInterval)
}
} else {
if connectionLost {
connectionLost = false
logger.Info("Connection to server restored!")
}
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping", currentInterval)
}
consecutiveFailures = 0
}
case <-pingStopChan:
logger.Info("Stopping ping check")
return
}
}
}()
return pingStopChan
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":