Fix the bind problem by just recreating the dev

TODO: WHY CANT WE REBIND TO A PORT - WE NEED TO FIX THIS BETTER
This commit is contained in:
Owen
2025-07-29 20:58:48 -07:00
parent dfba35f8bb
commit 45d17da570
3 changed files with 185 additions and 14 deletions

View File

@@ -54,6 +54,13 @@ func setupClientsNetstack(client *websocket.Client, host string) {
}
})
wgService.SetOnNetstackClose(func() {
if wgTesterServer != nil {
wgTesterServer.Stop()
wgTesterServer = nil
}
})
client.OnTokenUpdate(func(token string) {
wgService.SetToken(token)
})

View File

@@ -191,13 +191,13 @@ func (pm *ProxyManager) Stop() error {
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
// Clear the target maps
for k := range pm.tcpTargets {
delete(pm.tcpTargets, k)
}
for k := range pm.udpTargets {
delete(pm.udpTargets, k)
}
// // Clear the target maps
// for k := range pm.tcpTargets {
// delete(pm.tcpTargets, k)
// }
// for k := range pm.udpTargets {
// delete(pm.udpTargets, k)
// }
// Give active connections a chance to close gracefully
time.Sleep(100 * time.Millisecond)
@@ -368,3 +368,23 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
}
}
}
// write a function to print out the current targets in the ProxyManager
func (pm *ProxyManager) PrintTargets() {
pm.mutex.RLock()
defer pm.mutex.RUnlock()
logger.Info("Current TCP Targets:")
for listenIP, targets := range pm.tcpTargets {
for port, targetAddr := range targets {
logger.Info("TCP %s:%d -> %s", listenIP, port, targetAddr)
}
}
logger.Info("Current UDP Targets:")
for listenIP, targets := range pm.udpTargets {
for port, targetAddr := range targets {
logger.Info("UDP %s:%d -> %s", listenIP, port, targetAddr)
}
}
}

View File

@@ -85,6 +85,8 @@ type WireGuardService struct {
dns []netip.Addr
// Callback for when netstack is ready
onNetstackReady func(*netstack.Net)
// Callback for when netstack is closed
onNetstackClose func()
othertnet *netstack.Net
// Proxy manager for tunnel
proxyManager *proxy.ProxyManager
@@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
}
if len(targetData.Targets) > 0 {
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
}
}
@@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
}
if len(targetData.Targets) > 0 {
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
}
}
@@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
}
if len(targetData.Targets) > 0 {
updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
}
}
@@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
}
if len(targetData.Targets) > 0 {
updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
}
}
@@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
s.onNetstackReady = callback
}
func (s *WireGuardService) SetOnNetstackClose(callback func()) {
s.onNetstackClose = callback
}
func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(),
@@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
// add the targets if there are any
if len(config.Targets.TCP) > 0 {
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
}
if len(config.Targets.UDP) > 0 {
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
}
// Create ProxyManager for this tunnel
@@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
}
}
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
var replace = true
for _, t := range targetData.Targets {
// Split the first number off of the target with : separator and use as the port
parts := strings.Split(t, ":")
@@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
// Ignore "target not found" errors as this is expected for new targets
if !strings.Contains(err.Error(), "target not found") {
logger.Error("Failed to remove existing target: %v", err)
} else {
replace = false // If we got here, it means the target didn't exist, so we can add it without replacing
}
}
@@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
}
}
if replace {
// If we replaced any targets, we need to hot swap the netstack
if err := s.ReplaceNetstack(s.dns); err != nil {
logger.Error("Failed to replace netstack after updating targets: %v", err)
return err
}
logger.Info("Netstack replaced successfully after updating targets")
} else {
logger.Info("No targets updated, no netstack replacement needed")
}
return nil
}
@@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) {
}
return targetData, nil
}
// Add this method to WireGuardService
func (s *WireGuardService) ReplaceNetstack(newDNS []netip.Addr) error {
s.mu.Lock()
defer s.mu.Unlock()
if s.device == nil || s.tun == nil {
return fmt.Errorf("WireGuard device not initialized")
}
// Parse the current tunnel IP from the existing config
parts := strings.Split(s.config.IpAddress, "/")
if len(parts) != 2 {
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
}
tunnelIP := netip.MustParseAddr(parts[0])
// Stop the proxy manager temporarily
s.proxyManager.Stop()
// Create new TUN device and netstack with new DNS
newTun, newTnet, err := netstack.CreateNetTUN(
[]netip.Addr{tunnelIP},
newDNS,
s.mtu)
if err != nil {
// Restart proxy manager with old tnet on failure
s.proxyManager.Start()
return fmt.Errorf("failed to create new TUN device: %v", err)
}
// Get current device config before closing
currentConfig, err := s.device.IpcGet()
if err != nil {
newTun.Close()
s.proxyManager.Start()
return fmt.Errorf("failed to get current device config: %v", err)
}
// Filter out read-only fields from the config
filteredConfig := s.filterReadOnlyFields(currentConfig)
// if onNetstackClose callback is set, call it
if s.onNetstackClose != nil {
s.onNetstackClose()
}
// Close old device (this closes the old TUN device)
s.device.Close()
// Update references
s.tun = newTun
s.tnet = newTnet
s.dns = newDNS
// Create new WireGuard device with same port
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
device.LogLevelSilent,
"wireguard: ",
))
// Restore the configuration (without read-only fields)
err = s.device.IpcSet(filteredConfig)
if err != nil {
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
}
// Bring up the device
err = s.device.Up()
if err != nil {
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
}
// Update proxy manager with new tnet and restart
s.proxyManager.SetTNet(s.tnet)
s.proxyManager.Start()
s.proxyManager.PrintTargets()
// Call the netstack ready callback if set
if s.onNetstackReady != nil {
go s.onNetstackReady(s.tnet)
}
logger.Info("Netstack replaced successfully with new DNS servers")
return nil
}
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
func (s *WireGuardService) filterReadOnlyFields(config string) string {
lines := strings.Split(config, "\n")
var filteredLines []string
// List of read-only fields that should not be included in IpcSet
readOnlyFields := map[string]bool{
"last_handshake_time_sec": true,
"last_handshake_time_nsec": true,
"rx_bytes": true,
"tx_bytes": true,
"protocol_version": true,
}
for _, line := range lines {
if line == "" {
continue
}
// Check if this line contains a read-only field
isReadOnly := false
for field := range readOnlyFields {
if strings.HasPrefix(line, field+"=") {
isReadOnly = true
break
}
}
// Only include non-read-only lines
if !isReadOnly {
filteredLines = append(filteredLines, line)
}
}
return strings.Join(filteredLines, "\n")
}