diff --git a/main.go b/main.go index ff85cc2..6943668 100644 --- a/main.go +++ b/main.go @@ -420,7 +420,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if len(wgData.Targets.TCP) > 0 { updateTargets(pm, "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + if wgService != nil { updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "tcp", TargetData{Targets: wgData.Targets.TCP}) } } @@ -428,7 +428,7 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub if len(wgData.Targets.UDP) > 0 { updateTargets(pm, "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + if wgService != nil { updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "udp", TargetData{Targets: wgData.Targets.UDP}) } } @@ -647,9 +647,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub updateTargets(pm, "add", wgData.TunnelIP, "tcp", targetData) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { - updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "tcp", targetData) - } + // if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + // updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "tcp", targetData) + // } } }) @@ -672,9 +672,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub updateTargets(pm, "add", wgData.TunnelIP, "udp", targetData) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { - updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "udp", targetData) - } + // if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + // updateTargets(wgService.GetProxyManager(), "add", wgData.TunnelIP, "udp", targetData) + // } } }) @@ -697,9 +697,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub updateTargets(pm, "remove", wgData.TunnelIP, "udp", targetData) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { - updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "udp", targetData) - } + // if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + // updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "udp", targetData) + // } } }) @@ -722,9 +722,9 @@ persistent_keepalive_interval=5`, fixKey(privateKey.String()), fixKey(wgData.Pub updateTargets(pm, "remove", wgData.TunnelIP, "tcp", targetData) // Also update wgnetstack proxy manager - if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { - updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "tcp", targetData) - } + // if wgService != nil && wgService.GetNetstackNet() != nil && wgService.GetProxyManager() != nil { + // updateTargets(wgService.GetProxyManager(), "remove", wgData.TunnelIP, "tcp", targetData) + // } } }) diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index b5bc9b8..015bf15 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -29,8 +29,18 @@ import ( ) type WgConfig struct { - IpAddress string `json:"ipAddress"` - Peers []Peer `json:"peers"` + IpAddress string `json:"ipAddress"` + Peers []Peer `json:"peers"` + Targets TargetsByType `json:"targets"` +} + +type TargetsByType struct { + UDP []string `json:"udp"` + TCP []string `json:"tcp"` +} + +type TargetData struct { + Targets []string `json:"targets"` } type Peer struct { @@ -348,6 +358,18 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { if err := s.ensureWireguardPeers(config.Peers); err != nil { logger.Error("Failed to ensure WireGuard peers: %v", err) } + + // add the targets if there are any + if len(config.Targets.TCP) > 0 { + updateTargets(s.proxyManager, "add", config.IpAddress, "tcp", TargetData{Targets: config.Targets.TCP}) + } + + if len(config.Targets.UDP) > 0 { + updateTargets(s.proxyManager, "add", config.IpAddress, "udp", TargetData{Targets: config.Targets.UDP}) + } + + // Create ProxyManager for this tunnel + s.proxyManager.Start() } func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { @@ -410,9 +432,6 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { logger.Info("WireGuard netstack device created and configured") - // Create ProxyManager for this tunnel - s.proxyManager.Start() - // Store callback and tnet reference before releasing mutex callback := s.onNetstackReady tnet := s.tnet @@ -967,3 +986,52 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) { } } } + +func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + for _, t := range targetData.Targets { + // Split the first number off of the target with : separator and use as the port + parts := strings.Split(t, ":") + if len(parts) != 3 { + logger.Info("Invalid target format: %s", t) + continue + } + + // Get the port as an int + port := 0 + _, err := fmt.Sscanf(parts[0], "%d", &port) + if err != nil { + logger.Info("Invalid port: %s", parts[0]) + continue + } + + if action == "add" { + target := parts[1] + ":" + parts[2] + + // Call updown script if provided + processedTarget := target + + // Only remove the specific target if it exists + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + // Ignore "target not found" errors as this is expected for new targets + if !strings.Contains(err.Error(), "target not found") { + logger.Error("Failed to remove existing target: %v", err) + } + } + + // Add the new target + pm.AddTarget(proto, tunnelIP, port, processedTarget) + + } else if action == "remove" { + logger.Info("Removing target with port %d", port) + + err := pm.RemoveTarget(proto, tunnelIP, port) + if err != nil { + logger.Error("Failed to remove target: %v", err) + return err + } + } + } + + return nil +}