diff --git a/clients.go b/clients.go index 78f7844..60bbf41 100644 --- a/clients.go +++ b/clients.go @@ -54,6 +54,13 @@ func setupClientsNetstack(client *websocket.Client, host string) { } }) + wgService.SetOnNetstackClose(func() { + if wgTesterServer != nil { + wgTesterServer.Stop() + wgTesterServer = nil + } + }) + client.OnTokenUpdate(func(token string) { wgService.SetToken(token) }) diff --git a/proxy/manager.go b/proxy/manager.go index 43a7da7..4d14582 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -191,13 +191,13 @@ func (pm *ProxyManager) Stop() error { pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) } - // Clear the target maps - for k := range pm.tcpTargets { - delete(pm.tcpTargets, k) - } - for k := range pm.udpTargets { - delete(pm.udpTargets, k) - } + // // Clear the target maps + // for k := range pm.tcpTargets { + // delete(pm.tcpTargets, k) + // } + // for k := range pm.udpTargets { + // delete(pm.udpTargets, k) + // } // Give active connections a chance to close gracefully time.Sleep(100 * time.Millisecond) @@ -368,3 +368,23 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { } } } + +// write a function to print out the current targets in the ProxyManager +func (pm *ProxyManager) PrintTargets() { + pm.mutex.RLock() + defer pm.mutex.RUnlock() + + logger.Info("Current TCP Targets:") + for listenIP, targets := range pm.tcpTargets { + for port, targetAddr := range targets { + logger.Info("TCP %s:%d -> %s", listenIP, port, targetAddr) + } + } + + logger.Info("Current UDP Targets:") + for listenIP, targets := range pm.udpTargets { + for port, targetAddr := range targets { + logger.Info("UDP %s:%d -> %s", listenIP, port, targetAddr) + } + } +} diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index a61a018..a76e2ee 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -85,6 +85,8 @@ type WireGuardService struct { dns []netip.Addr // Callback for when netstack is ready onNetstackReady func(*netstack.Net) + // Callback for when netstack is closed + onNetstackClose func() othertnet *netstack.Net // Proxy manager for tunnel proxyManager *proxy.ProxyManager @@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) { } if len(targetData.Targets) > 0 { - updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData) + s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData) } } @@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) { } if len(targetData.Targets) > 0 { - updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData) + s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData) } } @@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) { } if len(targetData.Targets) > 0 { - updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData) + s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData) } } @@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) { } if len(targetData.Targets) > 0 { - updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData) + s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData) } } @@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) { s.onNetstackReady = callback } +func (s *WireGuardService) SetOnNetstackClose(callback func()) { + s.onNetstackClose = callback +} + func (s *WireGuardService) LoadRemoteConfig() error { s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ "publicKey": s.key.PublicKey().String(), @@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { // add the targets if there are any if len(config.Targets.TCP) > 0 { - updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP}) + s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP}) } if len(config.Targets.UDP) > 0 { - updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP}) + s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP}) } // Create ProxyManager for this tunnel @@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) { } } -func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { +func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error { + var replace = true for _, t := range targetData.Targets { // Split the first number off of the target with : separator and use as the port parts := strings.Split(t, ":") @@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto // Ignore "target not found" errors as this is expected for new targets if !strings.Contains(err.Error(), "target not found") { logger.Error("Failed to remove existing target: %v", err) + } else { + replace = false // If we got here, it means the target didn't exist, so we can add it without replacing } } @@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto } } + if replace { + // If we replaced any targets, we need to hot swap the netstack + if err := s.ReplaceNetstack(s.dns); err != nil { + logger.Error("Failed to replace netstack after updating targets: %v", err) + return err + } + logger.Info("Netstack replaced successfully after updating targets") + } else { + logger.Info("No targets updated, no netstack replacement needed") + } + return nil } @@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) { } return targetData, nil } + +// Add this method to WireGuardService +func (s *WireGuardService) ReplaceNetstack(newDNS []netip.Addr) error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.device == nil || s.tun == nil { + return fmt.Errorf("WireGuard device not initialized") + } + + // Parse the current tunnel IP from the existing config + parts := strings.Split(s.config.IpAddress, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress) + } + tunnelIP := netip.MustParseAddr(parts[0]) + + // Stop the proxy manager temporarily + s.proxyManager.Stop() + + // Create new TUN device and netstack with new DNS + newTun, newTnet, err := netstack.CreateNetTUN( + []netip.Addr{tunnelIP}, + newDNS, + s.mtu) + if err != nil { + // Restart proxy manager with old tnet on failure + s.proxyManager.Start() + return fmt.Errorf("failed to create new TUN device: %v", err) + } + + // Get current device config before closing + currentConfig, err := s.device.IpcGet() + if err != nil { + newTun.Close() + s.proxyManager.Start() + return fmt.Errorf("failed to get current device config: %v", err) + } + + // Filter out read-only fields from the config + filteredConfig := s.filterReadOnlyFields(currentConfig) + + // if onNetstackClose callback is set, call it + if s.onNetstackClose != nil { + s.onNetstackClose() + } + + // Close old device (this closes the old TUN device) + s.device.Close() + + // Update references + s.tun = newTun + s.tnet = newTnet + s.dns = newDNS + + // Create new WireGuard device with same port + s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger( + device.LogLevelSilent, + "wireguard: ", + )) + + // Restore the configuration (without read-only fields) + err = s.device.IpcSet(filteredConfig) + if err != nil { + return fmt.Errorf("failed to restore WireGuard configuration: %v", err) + } + + // Bring up the device + err = s.device.Up() + if err != nil { + return fmt.Errorf("failed to bring up new WireGuard device: %v", err) + } + + // Update proxy manager with new tnet and restart + s.proxyManager.SetTNet(s.tnet) + s.proxyManager.Start() + + s.proxyManager.PrintTargets() + + // Call the netstack ready callback if set + if s.onNetstackReady != nil { + go s.onNetstackReady(s.tnet) + } + + logger.Info("Netstack replaced successfully with new DNS servers") + return nil +} + +// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration +func (s *WireGuardService) filterReadOnlyFields(config string) string { + lines := strings.Split(config, "\n") + var filteredLines []string + + // List of read-only fields that should not be included in IpcSet + readOnlyFields := map[string]bool{ + "last_handshake_time_sec": true, + "last_handshake_time_nsec": true, + "rx_bytes": true, + "tx_bytes": true, + "protocol_version": true, + } + + for _, line := range lines { + if line == "" { + continue + } + + // Check if this line contains a read-only field + isReadOnly := false + for field := range readOnlyFields { + if strings.HasPrefix(line, field+"=") { + isReadOnly = true + break + } + } + + // Only include non-read-only lines + if !isReadOnly { + filteredLines = append(filteredLines, line) + } + } + + return strings.Join(filteredLines, "\n") +}