Update main.go

This commit is contained in:
Wouter van Elten
2025-06-30 12:50:33 +02:00
committed by GitHub
parent e357e7befb
commit 700287163e

295
main.go
View File

@@ -505,4 +505,299 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub
}
})
}
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
client.RegisterHandler("newt/udp/add", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
logger.Info("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
logger.Info("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
logger.Info("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
// Register handler for Docker socket check
client.RegisterHandler("newt/socket/check", func(msg websocket.WSMessage) {
logger.Info("Received Docker socket check request")
if dockerSocket == "" {
logger.Info("Docker socket path is not set")
err := client.SendMessage("newt/socket/status", map[string]interface{}{
"available": false,
"socketPath": dockerSocket,
})
if err != nil {
logger.Error("Failed to send Docker socket check response: %v", err)
}
return
}
// Check if Docker socket is available
isAvailable := docker.CheckSocket(dockerSocket)
// Send response back to server
err := client.SendMessage("newt/socket/status", map[string]interface{}{
"available": isAvailable,
"socketPath": dockerSocket,
})
if err != nil {
logger.Error("Failed to send Docker socket check response: %v", err)
} else {
logger.Info("Docker socket check response sent: available=%t", isAvailable)
}
})
// Register handler for Docker container listing
client.RegisterHandler("newt/socket/fetch", func(msg websocket.WSMessage) {
logger.Info("Received Docker container fetch request")
if dockerSocket == "" {
logger.Info("Docker socket path is not set")
return
}
// List Docker containers
containers, err := docker.ListContainers(dockerSocket)
if err != nil {
logger.Error("Failed to list Docker containers: %v", err)
return
}
// Send container list back to server
err = client.SendMessage("newt/socket/containers", map[string]interface{}{
"containers": containers,
})
if err != nil {
logger.Error("Failed to send Docker container list: %v", err)
} else {
logger.Info("Docker container list sent, count: %d", len(containers))
}
})
client.OnConnect(func() error {
publicKey := privateKey.PublicKey()
logger.Debug("Public key: %s", publicKey)
err := client.SendMessage("newt/wg/register", map[string]interface{}{
"publicKey": publicKey.String(),
})
if err != nil {
logger.Error("Failed to send registration message: %v", err)
return err
}
logger.Info("Sent registration message")
return nil
})
// Connect to the WebSocket server
if err := client.Connect(); err != nil {
logger.Fatal("Failed to connect to server: %v", err)
}
defer client.Close()
// Wait for interrupt signal
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
sigReceived := <-sigCh
// Cleanup
logger.Info("Received %s signal, stopping", sigReceived.String())
if dev != nil {
dev.Close()
}
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
logger.Info("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
logger.Info("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 3 {
logger.Info("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
logger.Info("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1] + ":" + parts[2]
// Call updown script if provided
processedTarget := target
if updownScript != "" {
newTarget, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
} else if newTarget != "" {
processedTarget = newTarget
}
}
// Only remove the specific target if it exists
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
}
}
// Add the new target
pm.AddTarget(proto, tunnelIP, port, processedTarget)
} else if action == "remove" {
logger.Info("Removing target with port %d", port)
target := parts[1] + ":" + parts[2]
// Call updown script if provided
if updownScript != "" {
_, err := executeUpdownScript(action, proto, target)
if err != nil {
logger.Warn("Updown script error: %v", err)
}
}
err := pm.RemoveTarget(proto, tunnelIP, port)
if err != nil {
logger.Error("Failed to remove target: %v", err)
return err
}
}
}
return nil
}
func executeUpdownScript(action, proto, target string) (string, error) {
if updownScript == "" {
return target, nil
}
// Split the updownScript in case it contains spaces (like "/usr/bin/python3 script.py")
parts := strings.Fields(updownScript)
if len(parts) == 0 {
return target, fmt.Errorf("invalid updown script command")
}
var cmd *exec.Cmd
if len(parts) == 1 {
// If it's a single executable
logger.Info("Executing updown script: %s %s %s %s", updownScript, action, proto, target)
cmd = exec.Command(parts[0], action, proto, target)
} else {
// If it includes interpreter and script
args := append(parts[1:], action, proto, target)
logger.Info("Executing updown script: %s %s %s %s %s", parts[0], strings.Join(parts[1:], " "), action, proto, target)
cmd = exec.Command(parts[0], args...)
}
output, err := cmd.Output()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
return "", fmt.Errorf("updown script execution failed (exit code %d): %s",
exitErr.ExitCode(), string(exitErr.Stderr))
}
return "", fmt.Errorf("updown script execution failed: %v", err)
}
// If the script returns a new target, use it
newTarget := strings.TrimSpace(string(output))
if newTarget != "" {
logger.Info("Updown script returned new target: %s", newTarget)
return newTarget, nil
}
return target, nil
}