Update ping check

This commit is contained in:
Owen
2025-06-18 22:54:13 -04:00
parent b397016da8
commit 4b64b04603
2 changed files with 483 additions and 449 deletions

580
main.go
View File

@@ -1,19 +1,13 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"math"
"math/rand"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
"runtime"
"strconv"
@@ -28,8 +22,6 @@ import (
"github.com/fosrl/newt/wg"
"github.com/fosrl/newt/wgtester"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun"
@@ -60,313 +52,13 @@ type ExitNodeData struct {
// ExitNode represents an exit node with an ID, endpoint, and weight.
type ExitNode struct {
ID int `json:"exitNodeId"`
Name string `json:"exitNodeName"`
Endpoint string `json:"endpoint"`
Weight float64 `json:"weight"`
ID int `json:"exitNodeId"`
Name string `json:"exitNodeName"`
Endpoint string `json:"endpoint"`
Weight float64 `json:"weight"`
WasPreviouslyConnected bool `json:"wasPreviouslyConnected"`
}
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64: %v", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
}
defer socket.Close()
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
if err != nil {
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return 0, fmt.Errorf("failed to set read deadline: %w", err)
}
start := time.Now()
_, err = socket.Write(icmpBytes)
if err != nil {
return 0, fmt.Errorf("failed to write ICMP packet: %w", err)
}
n, err := socket.Read(icmpBytes[:])
if err != nil {
return 0, fmt.Errorf("failed to read ICMP packet: %w", err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
return 0, fmt.Errorf("failed to parse ICMP packet: %w", err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
return 0, fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
return 0, fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
}
latency := time.Since(start)
return latency, nil
}
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
initialInterval := 10 * time.Second
maxInterval := 60 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
go func() {
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP)
if err != nil {
consecutiveFailures++
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
consecutiveFailures, err)
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
// Increase interval if we have consistent failures, with a maximum cap
if consecutiveFailures >= 3 && currentInterval < maxInterval {
// Increase by 50% each time, up to the maximum
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Info("Increased ping check interval to %v due to consecutive failures",
currentInterval)
}
} else {
// On success, if we've backed off, gradually return to normal interval
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping",
currentInterval)
}
consecutiveFailures = 0
}
case <-stopChan:
logger.Info("Stopping ping check")
return
}
}
}()
}
// Function to track connection status and trigger reconnection as needed
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
const checkInterval = 30 * time.Second
connectionLost := false
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
// Try a ping to see if connection is alive
_, err := ping(tnet, serverIP)
if err != nil && !connectionLost {
// We just lost connection
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
// Notify the user they might need to check their network
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
} else if err == nil && connectionLost {
// Connection has been restored
connectionLost = false
logger.Info("Connection to server restored!")
// Tell the server we're back
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": privateKey.PublicKey().String(),
})
if err != nil {
logger.Error("Failed to send registration message after reconnection: %v", err)
} else {
logger.Info("Successfully re-registered with server after reconnection")
}
}
}
}
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
const (
initialMaxAttempts = 15
initialRetryDelay = 2 * time.Second
maxRetryDelay = 60 * time.Second // Cap the maximum delay
)
attempt := 1
retryDelay := initialRetryDelay
// First try with the initial parameters
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err == nil {
// Successful ping
logger.Info("Ping latency: %v", latency)
logger.Info("Tunnel connection to server established successfully!")
return nil
} else {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
}
// Start a goroutine that will attempt pings indefinitely with increasing delays
go func() {
attempt = 2 // Continue from attempt 2
for {
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err != nil {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
// Increase delay after certain thresholds but cap it
if attempt%5 == 0 && retryDelay < maxRetryDelay {
retryDelay = time.Duration(float64(retryDelay) * 1.5)
if retryDelay > maxRetryDelay {
retryDelay = maxRetryDelay
}
logger.Info("Increasing ping retry delay to %v", retryDelay)
}
time.Sleep(retryDelay)
attempt++
} else {
// Successful ping
logger.Info("Ping succeeded after %d attempts", attempt)
logger.Info("Ping latency: %v", latency)
logger.Info("Tunnel connection to server established successfully!")
return
}
}
}()
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
return fmt.Errorf("initial ping attempts failed, continuing in background")
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func mapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
func resolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
var (
endpoint string
id string
@@ -524,17 +216,6 @@ func main() {
}
}
client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) {
logger.Info("Received terminate message")
if pm != nil {
pm.Stop()
}
if dev != nil {
dev.Close()
}
client.Close()
})
pingStopChan := make(chan struct{})
defer close(pingStopChan)
@@ -543,10 +224,32 @@ func main() {
logger.Info("Received registration message")
if connected {
logger.Info("Already connected! But I will send a ping anyway...")
// Even if pingWithRetry returns an error, it will continue trying in the background
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
return
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
}
jsonData, err := json.Marshal(msg.Data)
@@ -612,10 +315,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// as the pings will continue in the background
if !connected {
logger.Info("Starting ping check")
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
// Start connection monitoring in a separate goroutine
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
startPingCheck(tnet, wgData.ServerIP, client, pingStopChan)
}
// Create proxy manager
@@ -645,6 +345,39 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
})
client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) {
logger.Info("Received disconnect message")
// Stop proxy manager if running
if pm != nil {
pm.Stop()
pm = nil
}
// Close WireGuard device if running
if dev != nil {
dev.Close()
dev = nil
}
// Close TUN/netstack if running
if tnet != nil {
tnet = nil
}
if tun != nil {
tun.Close()
tun = nil
}
// Stop the ping check
close(pingStopChan)
// Mark as disconnected
connected = false
logger.Info("Tunnel destroyed, ready for reconnection")
})
client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) {
logger.Info("Received ping message")
@@ -747,7 +480,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
// check if last used node is close enough in score
for _, cand := range candidateNodes {
if cand.Node.WasPreviouslyConnected {
if bestScore - cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) {
logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint)
bestNode = &cand.Node
}
@@ -966,126 +699,75 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
os.Exit(0)
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) {
initialInterval := 10 * time.Second
maxInterval := 60 * time.Second
currentInterval := initialInterval
consecutiveFailures := 0
connectionLost := false
ticker := time.NewTicker(currentInterval)
defer ticker.Stop()
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
newTarget, err := executeUpdownScript(action, proto, target)
go func() {
for {
select {
case <-ticker.C:
_, err := ping(tnet, serverIP)
if err != nil {
logger.Warn("Updown script error: %v", err)
} else if newTarget != "" {
processedTarget = newTarget
consecutiveFailures++
// Check if this is the first failure (connection just lost)
if !connectionLost {
connectionLost = true
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
}
logger.Warn("Periodic ping failed (%d consecutive failures): %v",
consecutiveFailures, err)
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
// Increase interval if we have consistent failures, with a maximum cap
if consecutiveFailures >= 5 && currentInterval < maxInterval {
// Increase by 50% each time, up to the maximum
currentInterval = time.Duration(float64(currentInterval) * 1.5)
if currentInterval > maxInterval {
currentInterval = maxInterval
}
ticker.Reset(currentInterval)
logger.Debug("Increased ping check interval to %v due to consecutive failures",
currentInterval)
// Restart the connection flow
err := client.SendMessage("newt/ping/request", map[string]interface{}{})
if err != nil {
logger.Error("Failed to send ping request: %v", err)
}
}
} else {
// Check if connection was previously lost and is now restored
if connectionLost {
connectionLost = false
logger.Info("Connection to server restored!")
}
// On success, if we've backed off, gradually return to normal interval
if currentInterval > initialInterval {
currentInterval = time.Duration(float64(currentInterval) * 0.8)
if currentInterval < initialInterval {
currentInterval = initialInterval
}
ticker.Reset(currentInterval)
logger.Info("Decreased ping check interval to %v after successful ping",
currentInterval)
}
consecutiveFailures = 0
}
}
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
}
}
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
case <-stopChan:
logger.Info("Stopping ping check")
return
}
}
}
return nil
}
func executeUpdownScript(action, proto, target string) (string, error) {
if updownScript == "" {
return target, nil
}
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
parts := strings.Fields(updownScript)
if len(parts) == 0 {
return target, fmt.Errorf("invalid updown script command")
}
var cmd *exec.Cmd
if len(parts) == 1 {
// If it's a single executable
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
cmd = exec.Command(parts[0], action, proto, target)
} else {
// If it includes interpreter and script
args := append(parts[1:], action, proto, target)
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
cmd = exec.Command(parts[0], args...)
}
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
exitErr.ExitCode(), string(exitErr.Stderr))
}
return "", fmt.Errorf("updown script execution failed: %v", err)
}
// If the script returns a new target, use it
newTarget := strings.TrimSpace(string(output))
if newTarget != "" {
logger.Info("Updown script returned new target: %s", newTarget)
return newTarget, nil
}
return target, nil
}()
}

352
util.go Normal file
View File

@@ -0,0 +1,352 @@
package main
import (
"bytes"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"net"
"os/exec"
"strings"
"time"
"github.com/fosrl/newt/logger"
"github.com/fosrl/newt/proxy"
"golang.org/x/exp/rand"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
)
func fixKey(key string) string {
// Remove any whitespace
key = strings.TrimSpace(key)
// Decode from base64
decoded, err := base64.StdEncoding.DecodeString(key)
if err != nil {
logger.Fatal("Error decoding base64: %v", err)
}
// Convert to hex
return hex.EncodeToString(decoded)
}
func ping(tnet *netstack.Net, dst string) (time.Duration, error) {
logger.Debug("Pinging %s", dst)
socket, err := tnet.Dial("ping4", dst)
if err != nil {
return 0, fmt.Errorf("failed to create ICMP socket: %w", err)
}
defer socket.Close()
requestPing := icmp.Echo{
Seq: rand.Intn(1 << 16),
Data: []byte("gopher burrow"),
}
icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil)
if err != nil {
return 0, fmt.Errorf("failed to marshal ICMP message: %w", err)
}
if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
return 0, fmt.Errorf("failed to set read deadline: %w", err)
}
start := time.Now()
_, err = socket.Write(icmpBytes)
if err != nil {
return 0, fmt.Errorf("failed to write ICMP packet: %w", err)
}
n, err := socket.Read(icmpBytes[:])
if err != nil {
return 0, fmt.Errorf("failed to read ICMP packet: %w", err)
}
replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n])
if err != nil {
return 0, fmt.Errorf("failed to parse ICMP packet: %w", err)
}
replyPing, ok := replyPacket.Body.(*icmp.Echo)
if !ok {
return 0, fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body)
}
if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq {
return 0, fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q",
replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data)
}
latency := time.Since(start)
return latency, nil
}
func pingWithRetry(tnet *netstack.Net, dst string) error {
const (
initialMaxAttempts = 15
initialRetryDelay = 2 * time.Second
maxRetryDelay = 60 * time.Second // Cap the maximum delay
)
attempt := 1
retryDelay := initialRetryDelay
// First try with the initial parameters
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err == nil {
// Successful ping
logger.Info("Ping latency: %v", latency)
logger.Info("Tunnel connection to server established successfully!")
return nil
} else {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
}
// Start a goroutine that will attempt pings indefinitely with increasing delays
go func() {
attempt = 2 // Continue from attempt 2
for {
logger.Info("Ping attempt %d", attempt)
if latency, err := ping(tnet, dst); err != nil {
logger.Warn("Ping attempt %d failed: %v", attempt, err)
// Increase delay after certain thresholds but cap it
if attempt%5 == 0 && retryDelay < maxRetryDelay {
retryDelay = time.Duration(float64(retryDelay) * 1.5)
if retryDelay > maxRetryDelay {
retryDelay = maxRetryDelay
}
logger.Info("Increasing ping retry delay to %v", retryDelay)
}
time.Sleep(retryDelay)
attempt++
} else {
// Successful ping
logger.Info("Ping succeeded after %d attempts", attempt)
logger.Info("Ping latency: %v", latency)
logger.Info("Tunnel connection to server established successfully!")
return
}
}
}()
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
return fmt.Errorf("initial ping attempts failed, continuing in background")
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
return logger.DEBUG
case "INFO":
return logger.INFO
case "WARN":
return logger.WARN
case "ERROR":
return logger.ERROR
case "FATAL":
return logger.FATAL
default:
return logger.INFO // default to INFO if invalid level provided
}
}
func mapToWireGuardLogLevel(level logger.LogLevel) int {
switch level {
case logger.DEBUG:
return device.LogLevelVerbose
// case logger.INFO:
// return device.LogLevel
case logger.WARN:
return device.LogLevelError
case logger.ERROR, logger.FATAL:
return device.LogLevelSilent
default:
return device.LogLevelSilent
}
}
func resolveDomain(domain string) (string, error) {
// Check if there's a port in the domain
host, port, err := net.SplitHostPort(domain)
if err != nil {
// No port found, use the domain as is
host = domain
port = ""
}
// Remove any protocol prefix if present
if strings.HasPrefix(host, "http://") {
host = strings.TrimPrefix(host, "http://")
} else if strings.HasPrefix(host, "https://") {
host = strings.TrimPrefix(host, "https://")
}
// if there are any trailing slashes, remove them
host = strings.TrimSuffix(host, "/")
// Lookup IP addresses
ips, err := net.LookupIP(host)
if err != nil {
return "", fmt.Errorf("DNS lookup failed: %v", err)
}
if len(ips) == 0 {
return "", fmt.Errorf("no IP addresses found for domain %s", host)
}
// Get the first IPv4 address if available
var ipAddr string
for _, ip := range ips {
if ipv4 := ip.To4(); ipv4 != nil {
ipAddr = ipv4.String()
break
}
}
// If no IPv4 found, use the first IP (might be IPv6)
if ipAddr == "" {
ipAddr = ips[0].String()
}
// Add port back if it existed
if port != "" {
ipAddr = net.JoinHostPort(ipAddr, port)
}
return ipAddr, nil
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
newTarget, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
} else if newTarget != "" {
processedTarget = newTarget
}
}
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
}
}
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
}
return nil
}
func executeUpdownScript(action, proto, target string) (string, error) {
if updownScript == "" {
return target, nil
}
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
parts := strings.Fields(updownScript)
if len(parts) == 0 {
return target, fmt.Errorf("invalid updown script command")
}
var cmd *exec.Cmd
if len(parts) == 1 {
// If it's a single executable
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
cmd = exec.Command(parts[0], action, proto, target)
} else {
// If it includes interpreter and script
args := append(parts[1:], action, proto, target)
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
cmd = exec.Command(parts[0], args...)
}
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
exitErr.ExitCode(), string(exitErr.Stderr))
}
return "", fmt.Errorf("updown script execution failed: %v", err)
}
// If the script returns a new target, use it
newTarget := strings.TrimSpace(string(output))
if newTarget != "" {
logger.Info("Updown script returned new target: %s", newTarget)
return newTarget, nil
}
return target, nil
}