From 4b64b0460361e40d853d87adbdaaf750ca429e67 Mon Sep 17 00:00:00 2001 From: Owen Date: Wed, 18 Jun 2025 22:54:13 -0400 Subject: [PATCH] Update ping check --- main.go | 580 +++++++++++++------------------------------------------- util.go | 352 ++++++++++++++++++++++++++++++++++ 2 files changed, 483 insertions(+), 449 deletions(-) create mode 100644 util.go diff --git a/main.go b/main.go index 369e14b..59ad194 100644 --- a/main.go +++ b/main.go @@ -1,19 +1,13 @@ package main import ( - "bytes" - "encoding/base64" - "encoding/hex" "encoding/json" "flag" "fmt" "math" - "math/rand" - "net" "net/http" "net/netip" "os" - "os/exec" "os/signal" "runtime" "strconv" @@ -28,8 +22,6 @@ import ( "github.com/fosrl/newt/wg" "github.com/fosrl/newt/wgtester" - "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/tun" @@ -60,313 +52,13 @@ type ExitNodeData struct { // ExitNode represents an exit node with an ID, endpoint, and weight. type ExitNode struct { - ID int `json:"exitNodeId"` - Name string `json:"exitNodeName"` - Endpoint string `json:"endpoint"` - Weight float64 `json:"weight"` + ID int `json:"exitNodeId"` + Name string `json:"exitNodeName"` + Endpoint string `json:"endpoint"` + Weight float64 `json:"weight"` WasPreviouslyConnected bool `json:"wasPreviouslyConnected"` } -func fixKey(key string) string { - // Remove any whitespace - key = strings.TrimSpace(key) - - // Decode from base64 - decoded, err := base64.StdEncoding.DecodeString(key) - if err != nil { - logger.Fatal("Error decoding base64: %v", err) - } - - // Convert to hex - return hex.EncodeToString(decoded) -} - -func ping(tnet *netstack.Net, dst string) (time.Duration, error) { - logger.Debug("Pinging %s", dst) - socket, err := tnet.Dial("ping4", dst) - if err != nil { - return 0, fmt.Errorf("failed to create ICMP socket: %w", err) - } - defer socket.Close() - - requestPing := icmp.Echo{ - Seq: rand.Intn(1 << 16), - Data: []byte("gopher burrow"), - } - - icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) - if err != nil { - return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) - } - - if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { - return 0, fmt.Errorf("failed to set read deadline: %w", err) - } - - start := time.Now() - _, err = socket.Write(icmpBytes) - if err != nil { - return 0, fmt.Errorf("failed to write ICMP packet: %w", err) - } - - n, err := socket.Read(icmpBytes[:]) - if err != nil { - return 0, fmt.Errorf("failed to read ICMP packet: %w", err) - } - - replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) - if err != nil { - return 0, fmt.Errorf("failed to parse ICMP packet: %w", err) - } - - replyPing, ok := replyPacket.Body.(*icmp.Echo) - if !ok { - return 0, fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) - } - - if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { - return 0, fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", - replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) - } - - latency := time.Since(start) - - return latency, nil -} - -func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) { - initialInterval := 10 * time.Second - maxInterval := 60 * time.Second - currentInterval := initialInterval - consecutiveFailures := 0 - - ticker := time.NewTicker(currentInterval) - defer ticker.Stop() - - go func() { - for { - select { - case <-ticker.C: - _, err := ping(tnet, serverIP) - if err != nil { - consecutiveFailures++ - logger.Warn("Periodic ping failed (%d consecutive failures): %v", - consecutiveFailures, err) - logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?") - - // Increase interval if we have consistent failures, with a maximum cap - if consecutiveFailures >= 3 && currentInterval < maxInterval { - // Increase by 50% each time, up to the maximum - currentInterval = time.Duration(float64(currentInterval) * 1.5) - if currentInterval > maxInterval { - currentInterval = maxInterval - } - ticker.Reset(currentInterval) - logger.Info("Increased ping check interval to %v due to consecutive failures", - currentInterval) - } - } else { - // On success, if we've backed off, gradually return to normal interval - if currentInterval > initialInterval { - currentInterval = time.Duration(float64(currentInterval) * 0.8) - if currentInterval < initialInterval { - currentInterval = initialInterval - } - ticker.Reset(currentInterval) - logger.Info("Decreased ping check interval to %v after successful ping", - currentInterval) - } - consecutiveFailures = 0 - } - case <-stopChan: - logger.Info("Stopping ping check") - return - } - } - }() -} - -// Function to track connection status and trigger reconnection as needed -func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) { - const checkInterval = 30 * time.Second - connectionLost := false - ticker := time.NewTicker(checkInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - // Try a ping to see if connection is alive - _, err := ping(tnet, serverIP) - - if err != nil && !connectionLost { - // We just lost connection - connectionLost = true - logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") - - // Notify the user they might need to check their network - logger.Warn("Please check your internet connection and ensure the Pangolin server is online.") - logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.") - } else if err == nil && connectionLost { - // Connection has been restored - connectionLost = false - logger.Info("Connection to server restored!") - - // Tell the server we're back - err := client.SendMessage("newt/wg/register", map[string]interface{}{ - "publicKey": privateKey.PublicKey().String(), - }) - - if err != nil { - logger.Error("Failed to send registration message after reconnection: %v", err) - } else { - logger.Info("Successfully re-registered with server after reconnection") - } - } - } - } -} - -func pingWithRetry(tnet *netstack.Net, dst string) error { - const ( - initialMaxAttempts = 15 - initialRetryDelay = 2 * time.Second - maxRetryDelay = 60 * time.Second // Cap the maximum delay - ) - - attempt := 1 - retryDelay := initialRetryDelay - - // First try with the initial parameters - logger.Info("Ping attempt %d", attempt) - if latency, err := ping(tnet, dst); err == nil { - // Successful ping - logger.Info("Ping latency: %v", latency) - - logger.Info("Tunnel connection to server established successfully!") - return nil - } else { - logger.Warn("Ping attempt %d failed: %v", attempt, err) - } - - // Start a goroutine that will attempt pings indefinitely with increasing delays - go func() { - attempt = 2 // Continue from attempt 2 - - for { - logger.Info("Ping attempt %d", attempt) - - if latency, err := ping(tnet, dst); err != nil { - logger.Warn("Ping attempt %d failed: %v", attempt, err) - - // Increase delay after certain thresholds but cap it - if attempt%5 == 0 && retryDelay < maxRetryDelay { - retryDelay = time.Duration(float64(retryDelay) * 1.5) - if retryDelay > maxRetryDelay { - retryDelay = maxRetryDelay - } - logger.Info("Increasing ping retry delay to %v", retryDelay) - } - - time.Sleep(retryDelay) - attempt++ - } else { - // Successful ping - logger.Info("Ping succeeded after %d attempts", attempt) - logger.Info("Ping latency: %v", latency) - logger.Info("Tunnel connection to server established successfully!") - return - } - } - }() - - // Return an error for the first batch of attempts (to maintain compatibility with existing code) - return fmt.Errorf("initial ping attempts failed, continuing in background") -} - -func parseLogLevel(level string) logger.LogLevel { - switch strings.ToUpper(level) { - case "DEBUG": - return logger.DEBUG - case "INFO": - return logger.INFO - case "WARN": - return logger.WARN - case "ERROR": - return logger.ERROR - case "FATAL": - return logger.FATAL - default: - return logger.INFO // default to INFO if invalid level provided - } -} - -func mapToWireGuardLogLevel(level logger.LogLevel) int { - switch level { - case logger.DEBUG: - return device.LogLevelVerbose - // case logger.INFO: - // return device.LogLevel - case logger.WARN: - return device.LogLevelError - case logger.ERROR, logger.FATAL: - return device.LogLevelSilent - default: - return device.LogLevelSilent - } -} - -func resolveDomain(domain string) (string, error) { - // Check if there's a port in the domain - host, port, err := net.SplitHostPort(domain) - if err != nil { - // No port found, use the domain as is - host = domain - port = "" - } - - // Remove any protocol prefix if present - if strings.HasPrefix(host, "http://") { - host = strings.TrimPrefix(host, "http://") - } else if strings.HasPrefix(host, "https://") { - host = strings.TrimPrefix(host, "https://") - } - - // if there are any trailing slashes, remove them - host = strings.TrimSuffix(host, "/") - - // Lookup IP addresses - ips, err := net.LookupIP(host) - if err != nil { - return "", fmt.Errorf("DNS lookup failed: %v", err) - } - - if len(ips) == 0 { - return "", fmt.Errorf("no IP addresses found for domain %s", host) - } - - // Get the first IPv4 address if available - var ipAddr string - for _, ip := range ips { - if ipv4 := ip.To4(); ipv4 != nil { - ipAddr = ipv4.String() - break - } - } - - // If no IPv4 found, use the first IP (might be IPv6) - if ipAddr == "" { - ipAddr = ips[0].String() - } - - // Add port back if it existed - if port != "" { - ipAddr = net.JoinHostPort(ipAddr, port) - } - - return ipAddr, nil -} - var ( endpoint string id string @@ -524,17 +216,6 @@ func main() { } } - client.RegisterHandler("newt/terminate", func(msg websocket.WSMessage) { - logger.Info("Received terminate message") - if pm != nil { - pm.Stop() - } - if dev != nil { - dev.Close() - } - client.Close() - }) - pingStopChan := make(chan struct{}) defer close(pingStopChan) @@ -543,10 +224,32 @@ func main() { logger.Info("Received registration message") if connected { - logger.Info("Already connected! But I will send a ping anyway...") - // Even if pingWithRetry returns an error, it will continue trying in the background - _ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue - return + // Stop proxy manager if running + if pm != nil { + pm.Stop() + pm = nil + } + + // Close WireGuard device if running + if dev != nil { + dev.Close() + dev = nil + } + + // Close TUN/netstack if running + if tnet != nil { + tnet = nil + } + if tun != nil { + tun.Close() + tun = nil + } + + // Stop the ping check + close(pingStopChan) + + // Mark as disconnected + connected = false } jsonData, err := json.Marshal(msg.Data) @@ -612,10 +315,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // as the pings will continue in the background if !connected { logger.Info("Starting ping check") - startPingCheck(tnet, wgData.ServerIP, pingStopChan) - - // Start connection monitoring in a separate goroutine - go monitorConnectionStatus(tnet, wgData.ServerIP, client) + startPingCheck(tnet, wgData.ServerIP, client, pingStopChan) } // Create proxy manager @@ -645,6 +345,39 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } }) + client.RegisterHandler("newt/wg/terminate", func(msg websocket.WSMessage) { + logger.Info("Received disconnect message") + + // Stop proxy manager if running + if pm != nil { + pm.Stop() + pm = nil + } + + // Close WireGuard device if running + if dev != nil { + dev.Close() + dev = nil + } + + // Close TUN/netstack if running + if tnet != nil { + tnet = nil + } + if tun != nil { + tun.Close() + tun = nil + } + + // Stop the ping check + close(pingStopChan) + + // Mark as disconnected + connected = false + + logger.Info("Tunnel destroyed, ready for reconnection") + }) + client.RegisterHandler("newt/ping/exitNodes", func(msg websocket.WSMessage) { logger.Info("Received ping message") @@ -747,7 +480,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub // check if last used node is close enough in score for _, cand := range candidateNodes { if cand.Node.WasPreviouslyConnected { - if bestScore - cand.Score <= bestScore*(scoreTolerancePercent/100.0) { + if bestScore-cand.Score <= bestScore*(scoreTolerancePercent/100.0) { logger.Info("Sticking with last used exit node: %s (%s), score close enough to best", cand.Node.Name, cand.Node.Endpoint) bestNode = &cand.Node } @@ -966,126 +699,75 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub os.Exit(0) } -func parseTargetData(data interface{}) (TargetData, error) { - var targetData TargetData - jsonData, err := json.Marshal(data) - if err != nil { - logger.Info("Error marshaling data: %v", err) - return targetData, err - } +func startPingCheck(tnet *netstack.Net, serverIP string, client *websocket.Client, stopChan chan struct{}) { + initialInterval := 10 * time.Second + maxInterval := 60 * time.Second + currentInterval := initialInterval + consecutiveFailures := 0 + connectionLost := false + ticker := time.NewTicker(currentInterval) + defer ticker.Stop() - if err := json.Unmarshal(jsonData, &targetData); err != nil { - logger.Info("Error unmarshaling target data: %v", err) - return targetData, err - } - return targetData, nil -} - -func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { - for _, t := range targetData.Targets { - // Split the first number off of the target with : separator and use as the port - parts := strings.Split(t, ":") - if len(parts) != 3 { - logger.Info("Invalid target format: %s", t) - continue - } - - // Get the port as an int - port := 0 - _, err := fmt.Sscanf(parts[0], "%d", &port) - if err != nil { - logger.Info("Invalid port: %s", parts[0]) - continue - } - - if action == "add" { - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - processedTarget := target - if updownScript != "" { - newTarget, err := executeUpdownScript(action, proto, target) + go func() { + for { + select { + case <-ticker.C: + _, err := ping(tnet, serverIP) if err != nil { - logger.Warn("Updown script error: %v", err) - } else if newTarget != "" { - processedTarget = newTarget + consecutiveFailures++ + + // Check if this is the first failure (connection just lost) + if !connectionLost { + connectionLost = true + logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.") + logger.Warn("Please check your internet connection and ensure the Pangolin server is online.") + logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.") + } + + logger.Warn("Periodic ping failed (%d consecutive failures): %v", + consecutiveFailures, err) + logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?") + + // Increase interval if we have consistent failures, with a maximum cap + if consecutiveFailures >= 5 && currentInterval < maxInterval { + // Increase by 50% each time, up to the maximum + currentInterval = time.Duration(float64(currentInterval) * 1.5) + if currentInterval > maxInterval { + currentInterval = maxInterval + } + ticker.Reset(currentInterval) + logger.Debug("Increased ping check interval to %v due to consecutive failures", + currentInterval) + + // Restart the connection flow + err := client.SendMessage("newt/ping/request", map[string]interface{}{}) + if err != nil { + logger.Error("Failed to send ping request: %v", err) + } + } + } else { + // Check if connection was previously lost and is now restored + if connectionLost { + connectionLost = false + logger.Info("Connection to server restored!") + } + + // On success, if we've backed off, gradually return to normal interval + if currentInterval > initialInterval { + currentInterval = time.Duration(float64(currentInterval) * 0.8) + if currentInterval < initialInterval { + currentInterval = initialInterval + } + ticker.Reset(currentInterval) + logger.Info("Decreased ping check interval to %v after successful ping", + currentInterval) + } + consecutiveFailures = 0 } - } - - // Only remove the specific target if it exists - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - // Ignore "target not found" errors as this is expected for new targets - if !strings.Contains(err.Error(), "target not found") { - logger.Error("Failed to remove existing target: %v", err) - } - } - - // Add the new target - pm.AddTarget(proto, tunnelIP, port, processedTarget) - - } else if action == "remove" { - logger.Info("Removing target with port %d", port) - - target := parts[1] + ":" + parts[2] - - // Call updown script if provided - if updownScript != "" { - _, err := executeUpdownScript(action, proto, target) - if err != nil { - logger.Warn("Updown script error: %v", err) - } - } - - err := pm.RemoveTarget(proto, tunnelIP, port) - if err != nil { - logger.Error("Failed to remove target: %v", err) - return err + case <-stopChan: + logger.Info("Stopping ping check") + return } } - } - - return nil -} - -func executeUpdownScript(action, proto, target string) (string, error) { - if updownScript == "" { - return target, nil - } - - // Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py") - parts := strings.Fields(updownScript) - if len(parts) == 0 { - return target, fmt.Errorf("invalid updown script command") - } - - var cmd *exec.Cmd - if len(parts) == 1 { - // If it's a single executable - logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target) - cmd = exec.Command(parts[0], action, proto, target) - } else { - // If it includes interpreter and script - args := append(parts[1:], action, proto, target) - logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target) - cmd = exec.Command(parts[0], args...) - } - - output, err := cmd.Output() - if err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { - return "", fmt.Errorf("updown script execution failed (exit code %d): %s", - exitErr.ExitCode(), string(exitErr.Stderr)) - } - return "", fmt.Errorf("updown script execution failed: %v", err) - } - - // If the script returns a new target, use it - newTarget := strings.TrimSpace(string(output)) - if newTarget != "" { - logger.Info("Updown script returned new target: %s", newTarget) - return newTarget, nil - } - - return target, nil + }() } diff --git a/util.go b/util.go new file mode 100644 index 0000000..808ae2b --- /dev/null +++ b/util.go @@ -0,0 +1,352 @@ +package main + +import ( + "bytes" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "net" + "os/exec" + "strings" + "time" + + "github.com/fosrl/newt/logger" + "github.com/fosrl/newt/proxy" + "golang.org/x/exp/rand" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" +) + +func fixKey(key string) string { + // Remove any whitespace + key = strings.TrimSpace(key) + + // Decode from base64 + decoded, err := base64.StdEncoding.DecodeString(key) + if err != nil { + logger.Fatal("Error decoding base64: %v", err) + } + + // Convert to hex + return hex.EncodeToString(decoded) +} + +func ping(tnet *netstack.Net, dst string) (time.Duration, error) { + logger.Debug("Pinging %s", dst) + socket, err := tnet.Dial("ping4", dst) + if err != nil { + return 0, fmt.Errorf("failed to create ICMP socket: %w", err) + } + defer socket.Close() + + requestPing := icmp.Echo{ + Seq: rand.Intn(1 << 16), + Data: []byte("gopher burrow"), + } + + icmpBytes, err := (&icmp.Message{Type: ipv4.ICMPTypeEcho, Code: 0, Body: &requestPing}).Marshal(nil) + if err != nil { + return 0, fmt.Errorf("failed to marshal ICMP message: %w", err) + } + + if err := socket.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil { + return 0, fmt.Errorf("failed to set read deadline: %w", err) + } + + start := time.Now() + _, err = socket.Write(icmpBytes) + if err != nil { + return 0, fmt.Errorf("failed to write ICMP packet: %w", err) + } + + n, err := socket.Read(icmpBytes[:]) + if err != nil { + return 0, fmt.Errorf("failed to read ICMP packet: %w", err) + } + + replyPacket, err := icmp.ParseMessage(1, icmpBytes[:n]) + if err != nil { + return 0, fmt.Errorf("failed to parse ICMP packet: %w", err) + } + + replyPing, ok := replyPacket.Body.(*icmp.Echo) + if !ok { + return 0, fmt.Errorf("invalid reply type: got %T, want *icmp.Echo", replyPacket.Body) + } + + if !bytes.Equal(replyPing.Data, requestPing.Data) || replyPing.Seq != requestPing.Seq { + return 0, fmt.Errorf("invalid ping reply: got seq=%d data=%q, want seq=%d data=%q", + replyPing.Seq, replyPing.Data, requestPing.Seq, requestPing.Data) + } + + latency := time.Since(start) + + return latency, nil +} + +func pingWithRetry(tnet *netstack.Net, dst string) error { + const ( + initialMaxAttempts = 15 + initialRetryDelay = 2 * time.Second + maxRetryDelay = 60 * time.Second // Cap the maximum delay + ) + + attempt := 1 + retryDelay := initialRetryDelay + + // First try with the initial parameters + logger.Info("Ping attempt %d", attempt) + if latency, err := ping(tnet, dst); err == nil { + // Successful ping + logger.Info("Ping latency: %v", latency) + + logger.Info("Tunnel connection to server established successfully!") + return nil + } else { + logger.Warn("Ping attempt %d failed: %v", attempt, err) + } + + // Start a goroutine that will attempt pings indefinitely with increasing delays + go func() { + attempt = 2 // Continue from attempt 2 + + for { + logger.Info("Ping attempt %d", attempt) + + if latency, err := ping(tnet, dst); err != nil { + logger.Warn("Ping attempt %d failed: %v", attempt, err) + + // Increase delay after certain thresholds but cap it + if attempt%5 == 0 && retryDelay < maxRetryDelay { + retryDelay = time.Duration(float64(retryDelay) * 1.5) + if retryDelay > maxRetryDelay { + retryDelay = maxRetryDelay + } + logger.Info("Increasing ping retry delay to %v", retryDelay) + } + + time.Sleep(retryDelay) + attempt++ + } else { + // Successful ping + logger.Info("Ping succeeded after %d attempts", attempt) + logger.Info("Ping latency: %v", latency) + logger.Info("Tunnel connection to server established successfully!") + return + } + } + }() + + // Return an error for the first batch of attempts (to maintain compatibility with existing code) + return fmt.Errorf("initial ping attempts failed, continuing in background") +} + +func parseLogLevel(level string) logger.LogLevel { + switch strings.ToUpper(level) { + case "DEBUG": + return logger.DEBUG + case "INFO": + return logger.INFO + case "WARN": + return logger.WARN + case "ERROR": + return logger.ERROR + case "FATAL": + return logger.FATAL + default: + return logger.INFO // default to INFO if invalid level provided + } +} + +func mapToWireGuardLogLevel(level logger.LogLevel) int { + switch level { + case logger.DEBUG: + return device.LogLevelVerbose + // case logger.INFO: + // return device.LogLevel + case logger.WARN: + return device.LogLevelError + case logger.ERROR, logger.FATAL: + return device.LogLevelSilent + default: + return device.LogLevelSilent + } +} + +func resolveDomain(domain string) (string, error) { + // Check if there's a port in the domain + host, port, err := net.SplitHostPort(domain) + if err != nil { + // No port found, use the domain as is + host = domain + port = "" + } + + // Remove any protocol prefix if present + if strings.HasPrefix(host, "http://") { + host = strings.TrimPrefix(host, "http://") + } else if strings.HasPrefix(host, "https://") { + host = strings.TrimPrefix(host, "https://") + } + + // if there are any trailing slashes, remove them + host = strings.TrimSuffix(host, "/") + + // Lookup IP addresses + ips, err := net.LookupIP(host) + if err != nil { + return "", fmt.Errorf("DNS lookup failed: %v", err) + } + + if len(ips) == 0 { + return "", fmt.Errorf("no IP addresses found for domain %s", host) + } + + // Get the first IPv4 address if available + var ipAddr string + for _, ip := range ips { + if ipv4 := ip.To4(); ipv4 != nil { + ipAddr = ipv4.String() + break + } + } + + // If no IPv4 found, use the first IP (might be IPv6) + if ipAddr == "" { + ipAddr = ips[0].String() + } + + // Add port back if it existed + if port != "" { + ipAddr = net.JoinHostPort(ipAddr, port) + } + + return ipAddr, nil +} + +func parseTargetData(data interface{}) (TargetData, error) { + var targetData TargetData + jsonData, err := json.Marshal(data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return targetData, err + } + + if err := json.Unmarshal(jsonData, &targetData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return targetData, err + } + return targetData, nil +} + +func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Info("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + logger.Info("Invalid port: %s", parts[0]) + continue + } + + if action == "add" { + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + processedTarget := target + if updownScript != "" { + newTarget, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } else if newTarget != "" { + processedTarget = newTarget + } + } + + // Only remove the specific target if it exists + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + // Ignore "target not found" errors as this is expected for new targets + if !strings.Contains(err.Error(), "target not found") { + logger.Error("Failed to remove existing target: %v", err) + } + } + + // Add the new target + pm.AddTarget(proto, tunnelIP, port, processedTarget) + + } else if action == "remove" { + logger.Info("Removing target with port %d", port) + + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + if updownScript != "" { + _, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } + } + + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + logger.Error("Failed to remove target: %v", err) + return err + } + } + } + + return nil +} + +func executeUpdownScript(action, proto, target string) (string, error) { + if updownScript == "" { + return target, nil + } + + // Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py") + parts := strings.Fields(updownScript) + if len(parts) == 0 { + return target, fmt.Errorf("invalid updown script command") + } + + var cmd *exec.Cmd + if len(parts) == 1 { + // If it's a single executable + logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target) + cmd = exec.Command(parts[0], action, proto, target) + } else { + // If it includes interpreter and script + args := append(parts[1:], action, proto, target) + logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target) + cmd = exec.Command(parts[0], args...) + } + + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return "", fmt.Errorf("updown script execution failed (exit code %d): %s", + exitErr.ExitCode(), string(exitErr.Stderr)) + } + return "", fmt.Errorf("updown script execution failed: %v", err) + } + + // If the script returns a new target, use it + newTarget := strings.TrimSpace(string(output)) + if newTarget != "" { + logger.Info("Updown script returned new target: %s", newTarget) + return newTarget, nil + } + + return target, nil +}