diff --git a/main.go b/main.go index 116e6e5..ef85099 100644 --- a/main.go +++ b/main.go @@ -505,4 +505,299 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub } }) +} + client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) { + logger.Info("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info("No tunnel IP or proxy manager available") + return + } + + targetData, err := parseTargetData(msg.Data) + if err != nil { + logger.Info("Error parsing target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData) + } + }) + + client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) { + logger.Info("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info("No tunnel IP or proxy manager available") + return + } + + targetData, err := parseTargetData(msg.Data) + if err != nil { + logger.Info("Error parsing target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData) + } + }) + + client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) { + logger.Info("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info("No tunnel IP or proxy manager available") + return + } + + targetData, err := parseTargetData(msg.Data) + if err != nil { + logger.Info("Error parsing target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData) + } + }) + + client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) { + logger.Info("Received: %+v", msg) + + // if there is no wgData or pm, we can't add targets + if wgData.TunnelIP == "" || pm == nil { + logger.Info("No tunnel IP or proxy manager available") + return + } + + targetData, err := parseTargetData(msg.Data) + if err != nil { + logger.Info("Error parsing target data: %v", err) + return + } + + if len(targetData.Targets) > 0 { + updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData) + } + }) + + // Register handler for Docker socket check + client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) { + logger.Info("Received Docker socket check request") + + if dockerSocket == "" { + logger.Info("Docker socket path is not set") + err := client.SendMessage("newt/socket/status", map[string]interface{}{ + "available": false, + "socketPath": dockerSocket, + }) + if err != nil { + logger.Error("Failed to send Docker socket check response: %v", err) + } + return + } + + // Check if Docker socket is available + isAvailable := docker.CheckSocket(dockerSocket) + + // Send response back to server + err := client.SendMessage("newt/socket/status", map[string]interface{}{ + "available": isAvailable, + "socketPath": dockerSocket, + }) + if err != nil { + logger.Error("Failed to send Docker socket check response: %v", err) + } else { + logger.Info("Docker socket check response sent: available=%t", isAvailable) + } + }) + + // Register handler for Docker container listing + client.RegisterHandler("newt/socket/fetch", func(msg websocket.WSMessage) { + logger.Info("Received Docker container fetch request") + + if dockerSocket == "" { + logger.Info("Docker socket path is not set") + return + } + + // List Docker containers + containers, err := docker.ListContainers(dockerSocket) + if err != nil { + logger.Error("Failed to list Docker containers: %v", err) + return + } + + // Send container list back to server + err = client.SendMessage("newt/socket/containers", map[string]interface{}{ + "containers": containers, + }) + if err != nil { + logger.Error("Failed to send Docker container list: %v", err) + } else { + logger.Info("Docker container list sent, count: %d", len(containers)) + } + }) + + client.OnConnect(func() error { + publicKey := privateKey.PublicKey() + logger.Debug("Public key: %s", publicKey) + + err := client.SendMessage("newt/wg/register", map[string]interface{}{ + "publicKey": publicKey.String(), + }) + if err != nil { + logger.Error("Failed to send registration message: %v", err) + return err + } + + logger.Info("Sent registration message") + return nil + }) + + // Connect to the WebSocket server + if err := client.Connect(); err != nil { + logger.Fatal("Failed to connect to server: %v", err) + } + defer client.Close() + + // Wait for interrupt signal + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + sigReceived := <-sigCh + + // Cleanup + logger.Info("Received %s signal, stopping", sigReceived.String()) + if dev != nil { + dev.Close() + } +} + +func parseTargetData(data interface{}) (TargetData, error) { + var targetData TargetData + jsonData, err := json.Marshal(data) + if err != nil { + logger.Info("Error marshaling data: %v", err) + return targetData, err + } + + if err := json.Unmarshal(jsonData, &targetData); err != nil { + logger.Info("Error unmarshaling target data: %v", err) + return targetData, err + } + return targetData, nil +} + +func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Info("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + logger.Info("Invalid port: %s", parts[0]) + continue + } + + if action == "add" { + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + processedTarget := target + if updownScript != "" { + newTarget, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } else if newTarget != "" { + processedTarget = newTarget + } + } + + // Only remove the specific target if it exists + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + // Ignore "target not found" errors as this is expected for new targets + if !strings.Contains(err.Error(), "target not found") { + logger.Error("Failed to remove existing target: %v", err) + } + } + + // Add the new target + pm.AddTarget(proto, tunnelIP, port, processedTarget) + + } else if action == "remove" { + logger.Info("Removing target with port %d", port) + + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + if updownScript != "" { + _, err := executeUpdownScript(action, proto, target) + if err != nil { + logger.Warn("Updown script error: %v", err) + } + } + + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + logger.Error("Failed to remove target: %v", err) + return err + } + } + } + + return nil +} + +func executeUpdownScript(action, proto, target string) (string, error) { + if updownScript == "" { + return target, nil + } + + // Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py") + parts := strings.Fields(updownScript) + if len(parts) == 0 { + return target, fmt.Errorf("invalid updown script command") + } + + var cmd *exec.Cmd + if len(parts) == 1 { + // If it's a single executable + logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target) + cmd = exec.Command(parts[0], action, proto, target) + } else { + // If it includes interpreter and script + args := append(parts[1:], action, proto, target) + logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target) + cmd = exec.Command(parts[0], args...) + } + + output, err := cmd.Output() + if err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + return "", fmt.Errorf("updown script execution failed (exit code %d): %s", + exitErr.ExitCode(), string(exitErr.Stderr)) + } + return "", fmt.Errorf("updown script execution failed: %v", err) + } + + // If the script returns a new target, use it + newTarget := strings.TrimSpace(string(output)) + if newTarget != "" { + logger.Info("Updown script returned new target: %s", newTarget) + return newTarget, nil + } + + return target, nil }