Polish; add remove

This commit is contained in:
Owen Schwartz
2024-11-18 22:08:42 -05:00
parent 2e5531b4a5
commit 055d50d1d3
5 changed files with 146 additions and 98 deletions

207
main.go
View File

@@ -28,10 +28,20 @@ import (
)
type WgData struct {
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"`
Endpoint string `json:"endpoint"`
PublicKey string `json:"publicKey"`
ServerIP string `json:"serverIP"`
TunnelIP string `json:"tunnelIP"`
Targets TargetsByType `json:"targets"`
}
type TargetsByType struct {
UDP []string `json:"udp"`
TCP []string `json:"tcp"`
}
type TargetData struct {
Targets []string `json:"targets"`
}
func fixKey(key string) string {
@@ -177,6 +187,15 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
pm = proxy.NewProxyManager(tnet)
connected = true
// add the targets if there are any
if len(wgData.Targets.TCP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP})
}
if len(wgData.Targets.UDP) > 0 {
updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP})
}
})
client.RegisterHandler("newt/tcp/add", func(msg websocket.WSMessage) {
@@ -188,55 +207,14 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
return
}
type TargetData struct {
Targets []string `json:"targets"`
}
// Define a struct for the expected data structure
jsonData, err := json.Marshal(msg.Data)
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return
}
// Parse into our target structure
var targetData TargetData
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
// Stop the proxy manager before adding new targets
err = pm.Stop()
if err != nil {
log.Panic(err)
}
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
target := parts[1]
pm.AddTarget("tcp", wgData.TunnelIP, port, target)
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData)
}
})
@@ -249,51 +227,54 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
return
}
type TargetData struct {
Targets []string `json:"targets"`
}
jsonData, err := json.Marshal(msg.Data)
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return
}
var targetData TargetData
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
err = pm.Stop()
if err != nil {
log.Panic(err)
}
updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData)
}
})
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
client.RegisterHandler("newt/udp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
target := parts[1]
pm.AddTarget("udp", wgData.TunnelIP, port, target)
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData)
}
})
client.RegisterHandler("newt/tcp/remove", func(msg websocket.WSMessage) {
log.Printf("Received: %+v", msg)
// if there is no wgData or pm, we can't add targets
if wgData.TunnelIP == "" || pm == nil {
log.Printf("No tunnel IP or proxy manager available")
return
}
targetData, err := parseTargetData(msg.Data)
if err != nil {
log.Printf("Error parsing target data: %v", err)
return
}
if len(targetData.Targets) > 0 {
updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData)
}
})
@@ -303,10 +284,9 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
}
defer client.Close()
// TODO: we need to send the public key to the server to trigger it to respond to create the tunnel
// TODO: how to retry?
err = client.SendMessage("newt/wg/register", map[string]interface{}{
"content": "Hello, World!",
"publicKey": fmt.Sprintf("%s", privateKey),
})
if err != nil {
log.Printf("Failed to send message: %v", err)
@@ -320,3 +300,58 @@ persistent_keepalive_interval=5`, fmt.Sprintf("%s", privateKey), fixKey(wgData.P
// Cleanup
dev.Close()
}
func parseTargetData(data interface{}) (TargetData, error) {
var targetData TargetData
jsonData, err := json.Marshal(data)
if err != nil {
log.Printf("Error marshaling data: %v", err)
return targetData, err
}
if err := json.Unmarshal(jsonData, &targetData); err != nil {
log.Printf("Error unmarshaling target data: %v", err)
return targetData, err
}
return targetData, nil
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
// Stop the proxy manager before adding new targets
err := pm.Stop()
if err != nil {
log.Panic(err)
}
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
if len(parts) != 2 {
log.Printf("Invalid target format: %s", t)
continue
}
// Get the port as an int
port := 0
_, err := fmt.Sscanf(parts[0], "%d", &port)
if err != nil {
log.Printf("Invalid port: %s", parts[0])
continue
}
if action == "add" {
target := parts[1]
pm.AddTarget(proto, tunnelIP, port, target)
} else if action == "remove" {
pm.RemoveTarget(proto, tunnelIP, port)
}
}
err = pm.Start()
if err != nil {
log.Panic(err)
}
return nil
}