mirror of
https://github.com/fosrl/newt.git
synced 2026-02-24 22:06:38 +00:00
Update main.go
Added cli and env function
This commit is contained in:
403
main.go
403
main.go
@@ -50,16 +50,11 @@ type TargetData struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func fixKey(key string) string {
|
func fixKey(key string) string {
|
||||||
// Remove any whitespace
|
|
||||||
key = strings.TrimSpace(key)
|
key = strings.TrimSpace(key)
|
||||||
|
|
||||||
// Decode from base64
|
|
||||||
decoded, err := base64.StdEncoding.DecodeString(key)
|
decoded, err := base64.StdEncoding.DecodeString(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Error decoding base64: %v", err)
|
logger.Fatal("Error decoding base64: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Convert to hex
|
|
||||||
return hex.EncodeToString(decoded)
|
return hex.EncodeToString(decoded)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -115,7 +110,8 @@ func ping(tnet *netstack.Net, dst string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}) {
|
// --- CHANGED: added healthFile as parameter ---
|
||||||
|
func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{}, healthFile string) {
|
||||||
initialInterval := 10 * time.Second
|
initialInterval := 10 * time.Second
|
||||||
maxInterval := 60 * time.Second
|
maxInterval := 60 * time.Second
|
||||||
currentInterval := initialInterval
|
currentInterval := initialInterval
|
||||||
@@ -133,13 +129,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
|||||||
consecutiveFailures++
|
consecutiveFailures++
|
||||||
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
logger.Warn("Periodic ping failed (%d consecutive failures): %v", consecutiveFailures, err)
|
||||||
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
logger.Warn("HINT: Do you have UDP port 51820 (or the port in config.yml) open on your Pangolin server?")
|
||||||
// delete healthy file if failed 3 times
|
// --- CHANGED: Only remove file if healthFile is set ---
|
||||||
if consecutiveFailures >= 3 {
|
if consecutiveFailures >= 3 && healthFile != "" {
|
||||||
_ = os.Remove("/tmp/healthy")
|
_ = os.Remove(healthFile)
|
||||||
}
|
}
|
||||||
// Increase interval if we have consistent failures, with a maximum cap
|
// Increase interval if we have consistent failures, with a maximum cap
|
||||||
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
if consecutiveFailures >= 3 && currentInterval < maxInterval {
|
||||||
// Increase by 50% each time, up to the maximum
|
|
||||||
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
currentInterval = time.Duration(float64(currentInterval) * 1.5)
|
||||||
if currentInterval > maxInterval {
|
if currentInterval > maxInterval {
|
||||||
currentInterval = maxInterval
|
currentInterval = maxInterval
|
||||||
@@ -148,10 +143,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
|||||||
logger.Info("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
logger.Info("Increased ping check interval to %v due to consecutive failures", currentInterval)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Write a healthy file if ping successfull
|
// --- CHANGED: Only write file if healthFile is set ---
|
||||||
err := os.WriteFile("/tmp/healthy", []byte("ok"), 0644)
|
if healthFile != "" {
|
||||||
if err != nil {
|
err := os.WriteFile(healthFile, []byte("ok"), 0644)
|
||||||
logger.Warn("Failed to write health file: %v", err)
|
if err != nil {
|
||||||
|
logger.Warn("Failed to write health file: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// On success, if we've backed off, gradually return to normal interval
|
// On success, if we've backed off, gradually return to normal interval
|
||||||
if currentInterval > initialInterval {
|
if currentInterval > initialInterval {
|
||||||
@@ -166,13 +163,12 @@ func startPingCheck(tnet *netstack.Net, serverIP string, stopChan chan struct{})
|
|||||||
}
|
}
|
||||||
case <-stopChan:
|
case <-stopChan:
|
||||||
logger.Info("Stopping ping check")
|
logger.Info("Stopping ping check")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function to track connection status and trigger reconnection as needed
|
|
||||||
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
|
func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websocket.Client) {
|
||||||
const checkInterval = 30 * time.Second
|
const checkInterval = 30 * time.Second
|
||||||
connectionLost := false
|
connectionLost := false
|
||||||
@@ -182,27 +178,18 @@ func monitorConnectionStatus(tnet *netstack.Net, serverIP string, client *websoc
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Try a ping to see if connection is alive
|
|
||||||
err := ping(tnet, serverIP)
|
err := ping(tnet, serverIP)
|
||||||
|
|
||||||
if err != nil && !connectionLost {
|
if err != nil && !connectionLost {
|
||||||
// We just lost connection
|
|
||||||
connectionLost = true
|
connectionLost = true
|
||||||
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
logger.Warn("Connection to server lost. Continuous reconnection attempts will be made.")
|
||||||
|
|
||||||
// Notify the user they might need to check their network
|
|
||||||
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
logger.Warn("Please check your internet connection and ensure the Pangolin server is online.")
|
||||||
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
logger.Warn("Newt will continue reconnection attempts automatically when connectivity is restored.")
|
||||||
} else if err == nil && connectionLost {
|
} else if err == nil && connectionLost {
|
||||||
// Connection has been restored
|
|
||||||
connectionLost = false
|
connectionLost = false
|
||||||
logger.Info("Connection to server restored!")
|
logger.Info("Connection to server restored!")
|
||||||
|
|
||||||
// Tell the server we're back
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
||||||
"publicKey": privateKey.PublicKey().String(),
|
"publicKey": privateKey.PublicKey().String(),
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to send registration message after reconnection: %v", err)
|
logger.Error("Failed to send registration message after reconnection: %v", err)
|
||||||
} else {
|
} else {
|
||||||
@@ -217,32 +204,25 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
const (
|
const (
|
||||||
initialMaxAttempts = 15
|
initialMaxAttempts = 15
|
||||||
initialRetryDelay = 2 * time.Second
|
initialRetryDelay = 2 * time.Second
|
||||||
maxRetryDelay = 60 * time.Second // Cap the maximum delay
|
maxRetryDelay = 60 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
attempt := 1
|
attempt := 1
|
||||||
retryDelay := initialRetryDelay
|
retryDelay := initialRetryDelay
|
||||||
|
|
||||||
// First try with the initial parameters
|
|
||||||
logger.Info("Ping attempt %d", attempt)
|
logger.Info("Ping attempt %d", attempt)
|
||||||
if err := ping(tnet, dst); err == nil {
|
if err := ping(tnet, dst); err == nil {
|
||||||
// Successful ping
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start a goroutine that will attempt pings indefinitely with increasing delays
|
|
||||||
go func() {
|
go func() {
|
||||||
attempt = 2 // Continue from attempt 2
|
attempt = 2
|
||||||
|
|
||||||
for {
|
for {
|
||||||
logger.Info("Ping attempt %d", attempt)
|
logger.Info("Ping attempt %d", attempt)
|
||||||
|
|
||||||
if err := ping(tnet, dst); err != nil {
|
if err := ping(tnet, dst); err != nil {
|
||||||
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
logger.Warn("Ping attempt %d failed: %v", attempt, err)
|
||||||
|
|
||||||
// Increase delay after certain thresholds but cap it
|
|
||||||
if attempt%5 == 0 && retryDelay < maxRetryDelay {
|
if attempt%5 == 0 && retryDelay < maxRetryDelay {
|
||||||
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
retryDelay = time.Duration(float64(retryDelay) * 1.5)
|
||||||
if retryDelay > maxRetryDelay {
|
if retryDelay > maxRetryDelay {
|
||||||
@@ -250,18 +230,14 @@ func pingWithRetry(tnet *netstack.Net, dst string) error {
|
|||||||
}
|
}
|
||||||
logger.Info("Increasing ping retry delay to %v", retryDelay)
|
logger.Info("Increasing ping retry delay to %v", retryDelay)
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(retryDelay)
|
time.Sleep(retryDelay)
|
||||||
attempt++
|
attempt++
|
||||||
} else {
|
} else {
|
||||||
// Successful ping
|
|
||||||
logger.Info("Ping succeeded after %d attempts", attempt)
|
logger.Info("Ping succeeded after %d attempts", attempt)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Return an error for the first batch of attempts (to maintain compatibility with existing code)
|
|
||||||
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
return fmt.Errorf("initial ping attempts failed, continuing in background")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -278,7 +254,7 @@ func parseLogLevel(level string) logger.LogLevel {
|
|||||||
case "FATAL":
|
case "FATAL":
|
||||||
return logger.FATAL
|
return logger.FATAL
|
||||||
default:
|
default:
|
||||||
return logger.INFO // default to INFO if invalid level provided
|
return logger.INFO
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,8 +262,6 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|||||||
switch level {
|
switch level {
|
||||||
case logger.DEBUG:
|
case logger.DEBUG:
|
||||||
return device.LogLevelVerbose
|
return device.LogLevelVerbose
|
||||||
// case logger.INFO:
|
|
||||||
// return device.LogLevel
|
|
||||||
case logger.WARN:
|
case logger.WARN:
|
||||||
return device.LogLevelError
|
return device.LogLevelError
|
||||||
case logger.ERROR, logger.FATAL:
|
case logger.ERROR, logger.FATAL:
|
||||||
@@ -298,32 +272,23 @@ func mapToWireGuardLogLevel(level logger.LogLevel) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func resolveDomain(domain string) (string, error) {
|
func resolveDomain(domain string) (string, error) {
|
||||||
// Check if there's a port in the domain
|
|
||||||
host, port, err := net.SplitHostPort(domain)
|
host, port, err := net.SplitHostPort(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// No port found, use the domain as is
|
|
||||||
host = domain
|
host = domain
|
||||||
port = ""
|
port = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove any protocol prefix if present
|
|
||||||
if strings.HasPrefix(host, "http://") {
|
if strings.HasPrefix(host, "http://") {
|
||||||
host = strings.TrimPrefix(host, "http://")
|
host = strings.TrimPrefix(host, "http://")
|
||||||
} else if strings.HasPrefix(host, "https://") {
|
} else if strings.HasPrefix(host, "https://") {
|
||||||
host = strings.TrimPrefix(host, "https://")
|
host = strings.TrimPrefix(host, "https://")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Lookup IP addresses
|
|
||||||
ips, err := net.LookupIP(host)
|
ips, err := net.LookupIP(host)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
return "", fmt.Errorf("DNS lookup failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ips) == 0 {
|
if len(ips) == 0 {
|
||||||
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
return "", fmt.Errorf("no IP addresses found for domain %s", host)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the first IPv4 address if available
|
|
||||||
var ipAddr string
|
var ipAddr string
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
@@ -331,20 +296,16 @@ func resolveDomain(domain string) (string, error) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no IPv4 found, use the first IP (might be IPv6)
|
|
||||||
if ipAddr == "" {
|
if ipAddr == "" {
|
||||||
ipAddr = ips[0].String()
|
ipAddr = ips[0].String()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add port back if it existed
|
|
||||||
if port != "" {
|
if port != "" {
|
||||||
ipAddr = net.JoinHostPort(ipAddr, port)
|
ipAddr = net.JoinHostPort(ipAddr, port)
|
||||||
}
|
}
|
||||||
|
|
||||||
return ipAddr, nil
|
return ipAddr, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- ADDED: healthFile variable ---
|
||||||
var (
|
var (
|
||||||
endpoint string
|
endpoint string
|
||||||
id string
|
id string
|
||||||
@@ -358,10 +319,10 @@ var (
|
|||||||
updownScript string
|
updownScript string
|
||||||
tlsPrivateKey string
|
tlsPrivateKey string
|
||||||
dockerSocket string
|
dockerSocket string
|
||||||
|
healthFile string // NEW
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
// if PANGOLIN_ENDPOINT, NEWT_ID, and NEWT_SECRET are set as environment variables, they will be used as default values
|
|
||||||
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
endpoint = os.Getenv("PANGOLIN_ENDPOINT")
|
||||||
id = os.Getenv("NEWT_ID")
|
id = os.Getenv("NEWT_ID")
|
||||||
secret = os.Getenv("NEWT_SECRET")
|
secret = os.Getenv("NEWT_SECRET")
|
||||||
@@ -371,6 +332,7 @@ func main() {
|
|||||||
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
updownScript = os.Getenv("UPDOWN_SCRIPT")
|
||||||
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
tlsPrivateKey = os.Getenv("TLS_CLIENT_CERT")
|
||||||
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
dockerSocket = os.Getenv("DOCKER_SOCKET")
|
||||||
|
healthFile = os.Getenv("HEALTH_FILE") // NEW
|
||||||
|
|
||||||
if endpoint == "" {
|
if endpoint == "" {
|
||||||
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
flag.StringVar(&endpoint, "endpoint", "", "Endpoint of your pangolin server")
|
||||||
@@ -399,10 +361,12 @@ func main() {
|
|||||||
if dockerSocket == "" {
|
if dockerSocket == "" {
|
||||||
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
|
flag.StringVar(&dockerSocket, "docker-socket", "", "Path to Docker socket (typically /var/run/docker.sock)")
|
||||||
}
|
}
|
||||||
|
// --- ADDED: CLI flag for healthFile if not set by env ---
|
||||||
|
if healthFile == "" {
|
||||||
|
flag.StringVar(&healthFile, "health-file", "", "Path to health file (if unset, health file won’t be written)")
|
||||||
|
}
|
||||||
|
|
||||||
// do a --version check
|
|
||||||
version := flag.Bool("version", false, "Print the version")
|
version := flag.Bool("version", false, "Print the version")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
newtVersion := "Newt version replaceme"
|
newtVersion := "Newt version replaceme"
|
||||||
@@ -417,7 +381,6 @@ func main() {
|
|||||||
loggerLevel := parseLogLevel(logLevel)
|
loggerLevel := parseLogLevel(logLevel)
|
||||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||||
|
|
||||||
// parse the mtu string into an int
|
|
||||||
mtuInt, err = strconv.Atoi(mtu)
|
mtuInt, err = strconv.Atoi(mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to parse MTU: %v", err)
|
logger.Fatal("Failed to parse MTU: %v", err)
|
||||||
@@ -431,18 +394,13 @@ func main() {
|
|||||||
if tlsPrivateKey != "" {
|
if tlsPrivateKey != "" {
|
||||||
opt = websocket.WithTLSConfig(tlsPrivateKey)
|
opt = websocket.WithTLSConfig(tlsPrivateKey)
|
||||||
}
|
}
|
||||||
// Create a new client
|
|
||||||
client, err := websocket.NewClient(
|
client, err := websocket.NewClient(
|
||||||
id, // CLI arg takes precedence
|
id, secret, endpoint, opt,
|
||||||
secret, // CLI arg takes precedence
|
|
||||||
endpoint,
|
|
||||||
opt,
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal("Failed to create client: %v", err)
|
logger.Fatal("Failed to create client: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TUN device and network stack
|
|
||||||
var tun tun.Device
|
var tun tun.Device
|
||||||
var tnet *netstack.Net
|
var tnet *netstack.Net
|
||||||
var dev *device.Device
|
var dev *device.Device
|
||||||
@@ -464,14 +422,12 @@ func main() {
|
|||||||
pingStopChan := make(chan struct{})
|
pingStopChan := make(chan struct{})
|
||||||
defer close(pingStopChan)
|
defer close(pingStopChan)
|
||||||
|
|
||||||
// Register handlers for different message types
|
|
||||||
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
client.RegisterHandler("newt/wg/connect", func(msg websocket.WSMessage) {
|
||||||
logger.Info("Received registration message")
|
logger.Info("Received registration message")
|
||||||
|
|
||||||
if connected {
|
if connected {
|
||||||
logger.Info("Already connected! But I will send a ping anyway...")
|
logger.Info("Already connected! But I will send a ping anyway...")
|
||||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||||
_ = pingWithRetry(tnet, wgData.ServerIP) // Ignoring initial error as pings will continue
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,7 +451,6 @@ func main() {
|
|||||||
logger.Error("Failed to create TUN device: %v", err)
|
logger.Error("Failed to create TUN device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create WireGuard device
|
|
||||||
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
dev = device.NewDevice(tun, conn.NewDefaultBind(), device.NewLogger(
|
||||||
mapToWireGuardLogLevel(loggerLevel),
|
mapToWireGuardLogLevel(loggerLevel),
|
||||||
"wireguard: ",
|
"wireguard: ",
|
||||||
@@ -507,7 +462,6 @@ func main() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Configure WireGuard
|
|
||||||
config := fmt.Sprintf(`private_key=%s
|
config := fmt.Sprintf(`private_key=%s
|
||||||
public_key=%s
|
public_key=%s
|
||||||
allowed_ip=%s/32
|
allowed_ip=%s/32
|
||||||
@@ -519,7 +473,6 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
logger.Error("Failed to configure WireGuard device: %v", err)
|
logger.Error("Failed to configure WireGuard device: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bring up the device
|
|
||||||
err = dev.Up()
|
err = dev.Up()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error("Failed to bring up WireGuard device: %v", err)
|
logger.Error("Failed to bring up WireGuard device: %v", err)
|
||||||
@@ -527,29 +480,21 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
|
|
||||||
logger.Info("WireGuard device created. Lets ping the server now...")
|
logger.Info("WireGuard device created. Lets ping the server now...")
|
||||||
|
|
||||||
// Even if pingWithRetry returns an error, it will continue trying in the background
|
|
||||||
_ = pingWithRetry(tnet, wgData.ServerIP)
|
_ = pingWithRetry(tnet, wgData.ServerIP)
|
||||||
|
|
||||||
// Always mark as connected and start the proxy manager regardless of initial ping result
|
|
||||||
// as the pings will continue in the background
|
|
||||||
if !connected {
|
if !connected {
|
||||||
logger.Info("Starting ping check")
|
logger.Info("Starting ping check")
|
||||||
startPingCheck(tnet, wgData.ServerIP, pingStopChan)
|
// --- CHANGED: Pass healthFile to startPingCheck ---
|
||||||
|
startPingCheck(tnet, wgData.ServerIP, pingStopChan, healthFile)
|
||||||
// Start connection monitoring in a separate goroutine
|
|
||||||
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
|
go monitorConnectionStatus(tnet, wgData.ServerIP, client)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create proxy manager
|
|
||||||
pm = proxy.NewProxyManager(tnet)
|
pm = proxy.NewProxyManager(tnet)
|
||||||
|
|
||||||
connected = true
|
connected = true
|
||||||
|
|
||||||
// add the targets if there are any
|
|
||||||
if len(wgData.Targets.TCP) > 0 {
|
if len(wgData.Targets.TCP) > 0 {
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(wgData.Targets.UDP) > 0 {
|
if len(wgData.Targets.UDP) > 0 {
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
|
||||||
}
|
}
|
||||||
@@ -560,298 +505,4 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received: %+v", msg)
|
|
||||||
|
|
||||||
// if there is no wgData or pm, we can't add targets
|
|
||||||
if wgData.TunnelIP == "" || pm == nil {
|
|
||||||
logger.Info("No tunnel IP or proxy manager available")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
targetData, err := parseTargetData(msg.Data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error parsing target data: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
|
||||||
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Register handler for Docker socket check
|
|
||||||
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received Docker socket check request")
|
|
||||||
|
|
||||||
if dockerSocket == "" {
|
|
||||||
logger.Info("Docker socket path is not set")
|
|
||||||
err := client.SendMessage("newt/socket/status", map[string]interface{}{
|
|
||||||
"available": false,
|
|
||||||
"socketPath": dockerSocket,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send Docker socket check response: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if Docker socket is available
|
|
||||||
isAvailable := docker.CheckSocket(dockerSocket)
|
|
||||||
|
|
||||||
// Send response back to server
|
|
||||||
err := client.SendMessage("newt/socket/status", map[string]interface{}{
|
|
||||||
"available": isAvailable,
|
|
||||||
"socketPath": dockerSocket,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send Docker socket check response: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Docker socket check response sent: available=%t", isAvailable)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Register handler for Docker container listing
|
|
||||||
client.RegisterHandler("newt/socket/fetch", func(msg websocket.WSMessage) {
|
|
||||||
logger.Info("Received Docker container fetch request")
|
|
||||||
|
|
||||||
if dockerSocket == "" {
|
|
||||||
logger.Info("Docker socket path is not set")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// List Docker containers
|
|
||||||
containers, err := docker.ListContainers(dockerSocket)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to list Docker containers: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send container list back to server
|
|
||||||
err = client.SendMessage("newt/socket/containers", map[string]interface{}{
|
|
||||||
"containers": containers,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send Docker container list: %v", err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Docker container list sent, count: %d", len(containers))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
client.OnConnect(func() error {
|
|
||||||
publicKey := privateKey.PublicKey()
|
|
||||||
logger.Debug("Public key: %s", publicKey)
|
|
||||||
|
|
||||||
err := client.SendMessage("newt/wg/register", map[string]interface{}{
|
|
||||||
"publicKey": publicKey.String(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to send registration message: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Info("Sent registration message")
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
|
|
||||||
// Connect to the WebSocket server
|
|
||||||
if err := client.Connect(); err != nil {
|
|
||||||
logger.Fatal("Failed to connect to server: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
// Wait for interrupt signal
|
|
||||||
sigCh := make(chan os.Signal, 1)
|
|
||||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
sigReceived := <-sigCh
|
|
||||||
|
|
||||||
// Cleanup
|
|
||||||
logger.Info("Received %s signal, stopping", sigReceived.String())
|
|
||||||
if dev != nil {
|
|
||||||
dev.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseTargetData(data interface{}) (TargetData, error) {
|
|
||||||
var targetData TargetData
|
|
||||||
jsonData, err := json.Marshal(data)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Error marshaling data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(jsonData, &targetData); err != nil {
|
|
||||||
logger.Info("Error unmarshaling target data: %v", err)
|
|
||||||
return targetData, err
|
|
||||||
}
|
|
||||||
return targetData, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
|
||||||
for _, t := range targetData.Targets {
|
|
||||||
// Split the first number off of the target with : separator and use as the port
|
|
||||||
parts := strings.Split(t, ":")
|
|
||||||
if len(parts) != 3 {
|
|
||||||
logger.Info("Invalid target format: %s", t)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the port as an int
|
|
||||||
port := 0
|
|
||||||
_, err := fmt.Sscanf(parts[0], "%d", &port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Info("Invalid port: %s", parts[0])
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if action == "add" {
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
|
|
||||||
// Call updown script if provided
|
|
||||||
processedTarget := target
|
|
||||||
if updownScript != "" {
|
|
||||||
newTarget, err := executeUpdownScript(action, proto, target)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Updown script error: %v", err)
|
|
||||||
} else if newTarget != "" {
|
|
||||||
processedTarget = newTarget
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Only remove the specific target if it exists
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
// Ignore "target not found" errors as this is expected for new targets
|
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the new target
|
|
||||||
pm.AddTarget(proto, tunnelIP, port, processedTarget)
|
|
||||||
|
|
||||||
} else if action == "remove" {
|
|
||||||
logger.Info("Removing target with port %d", port)
|
|
||||||
|
|
||||||
target := parts[1] + ":" + parts[2]
|
|
||||||
|
|
||||||
// Call updown script if provided
|
|
||||||
if updownScript != "" {
|
|
||||||
_, err := executeUpdownScript(action, proto, target)
|
|
||||||
if err != nil {
|
|
||||||
logger.Warn("Updown script error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err := pm.RemoveTarget(proto, tunnelIP, port)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error("Failed to remove target: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func executeUpdownScript(action, proto, target string) (string, error) {
|
|
||||||
if updownScript == "" {
|
|
||||||
return target, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
|
|
||||||
parts := strings.Fields(updownScript)
|
|
||||||
if len(parts) == 0 {
|
|
||||||
return target, fmt.Errorf("invalid updown script command")
|
|
||||||
}
|
|
||||||
|
|
||||||
var cmd *exec.Cmd
|
|
||||||
if len(parts) == 1 {
|
|
||||||
// If it's a single executable
|
|
||||||
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
|
|
||||||
cmd = exec.Command(parts[0], action, proto, target)
|
|
||||||
} else {
|
|
||||||
// If it includes interpreter and script
|
|
||||||
args := append(parts[1:], action, proto, target)
|
|
||||||
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
|
|
||||||
cmd = exec.Command(parts[0], args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
output, err := cmd.Output()
|
|
||||||
if err != nil {
|
|
||||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
|
||||||
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
|
|
||||||
exitErr.ExitCode(), string(exitErr.Stderr))
|
|
||||||
}
|
|
||||||
return "", fmt.Errorf("updown script execution failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the script returns a new target, use it
|
|
||||||
newTarget := strings.TrimSpace(string(output))
|
|
||||||
if newTarget != "" {
|
|
||||||
logger.Info("Updown script returned new target: %s", newTarget)
|
|
||||||
return newTarget, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return target, nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user