mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Fix the bind problem by just recreating the dev
TODO: WHY CANT WE REBIND TO A PORT - WE NEED TO FIX THIS BETTER
This commit is contained in:
@@ -54,6 +54,13 @@ func setupClientsNetstack(client *websocket.Client, host string) {
|
||||
}
|
||||
})
|
||||
|
||||
wgService.SetOnNetstackClose(func() {
|
||||
if wgTesterServer != nil {
|
||||
wgTesterServer.Stop()
|
||||
wgTesterServer = nil
|
||||
}
|
||||
})
|
||||
|
||||
client.OnTokenUpdate(func(token string) {
|
||||
wgService.SetToken(token)
|
||||
})
|
||||
|
||||
@@ -191,13 +191,13 @@ func (pm *ProxyManager) Stop() error {
|
||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||
}
|
||||
|
||||
// Clear the target maps
|
||||
for k := range pm.tcpTargets {
|
||||
delete(pm.tcpTargets, k)
|
||||
}
|
||||
for k := range pm.udpTargets {
|
||||
delete(pm.udpTargets, k)
|
||||
}
|
||||
// // Clear the target maps
|
||||
// for k := range pm.tcpTargets {
|
||||
// delete(pm.tcpTargets, k)
|
||||
// }
|
||||
// for k := range pm.udpTargets {
|
||||
// delete(pm.udpTargets, k)
|
||||
// }
|
||||
|
||||
// Give active connections a chance to close gracefully
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
@@ -368,3 +368,23 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// write a function to print out the current targets in the ProxyManager
|
||||
func (pm *ProxyManager) PrintTargets() {
|
||||
pm.mutex.RLock()
|
||||
defer pm.mutex.RUnlock()
|
||||
|
||||
logger.Info("Current TCP Targets:")
|
||||
for listenIP, targets := range pm.tcpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
logger.Info("TCP %s:%d -> %s", listenIP, port, targetAddr)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info("Current UDP Targets:")
|
||||
for listenIP, targets := range pm.udpTargets {
|
||||
for port, targetAddr := range targets {
|
||||
logger.Info("UDP %s:%d -> %s", listenIP, port, targetAddr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,6 +85,8 @@ type WireGuardService struct {
|
||||
dns []netip.Addr
|
||||
// Callback for when netstack is ready
|
||||
onNetstackReady func(*netstack.Net)
|
||||
// Callback for when netstack is closed
|
||||
onNetstackClose func()
|
||||
othertnet *netstack.Net
|
||||
// Proxy manager for tunnel
|
||||
proxyManager *proxy.ProxyManager
|
||||
@@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
|
||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
|
||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
|
||||
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
|
||||
}
|
||||
|
||||
if len(targetData.Targets) > 0 {
|
||||
updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
|
||||
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
|
||||
s.onNetstackReady = callback
|
||||
}
|
||||
|
||||
func (s *WireGuardService) SetOnNetstackClose(callback func()) {
|
||||
s.onNetstackClose = callback
|
||||
}
|
||||
|
||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||
"publicKey": s.key.PublicKey().String(),
|
||||
@@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
||||
|
||||
// add the targets if there are any
|
||||
if len(config.Targets.TCP) > 0 {
|
||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
|
||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
|
||||
}
|
||||
|
||||
if len(config.Targets.UDP) > 0 {
|
||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
|
||||
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
|
||||
}
|
||||
|
||||
// Create ProxyManager for this tunnel
|
||||
@@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
||||
}
|
||||
}
|
||||
|
||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||
var replace = true
|
||||
for _, t := range targetData.Targets {
|
||||
// Split the first number off of the target with : separator and use as the port
|
||||
parts := strings.Split(t, ":")
|
||||
@@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
// Ignore "target not found" errors as this is expected for new targets
|
||||
if !strings.Contains(err.Error(), "target not found") {
|
||||
logger.Error("Failed to remove existing target: %v", err)
|
||||
} else {
|
||||
replace = false // If we got here, it means the target didn't exist, so we can add it without replacing
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
||||
}
|
||||
}
|
||||
|
||||
if replace {
|
||||
// If we replaced any targets, we need to hot swap the netstack
|
||||
if err := s.ReplaceNetstack(s.dns); err != nil {
|
||||
logger.Error("Failed to replace netstack after updating targets: %v", err)
|
||||
return err
|
||||
}
|
||||
logger.Info("Netstack replaced successfully after updating targets")
|
||||
} else {
|
||||
logger.Info("No targets updated, no netstack replacement needed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) {
|
||||
}
|
||||
return targetData, nil
|
||||
}
|
||||
|
||||
// Add this method to WireGuardService
|
||||
func (s *WireGuardService) ReplaceNetstack(newDNS []netip.Addr) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.device == nil || s.tun == nil {
|
||||
return fmt.Errorf("WireGuard device not initialized")
|
||||
}
|
||||
|
||||
// Parse the current tunnel IP from the existing config
|
||||
parts := strings.Split(s.config.IpAddress, "/")
|
||||
if len(parts) != 2 {
|
||||
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
|
||||
}
|
||||
tunnelIP := netip.MustParseAddr(parts[0])
|
||||
|
||||
// Stop the proxy manager temporarily
|
||||
s.proxyManager.Stop()
|
||||
|
||||
// Create new TUN device and netstack with new DNS
|
||||
newTun, newTnet, err := netstack.CreateNetTUN(
|
||||
[]netip.Addr{tunnelIP},
|
||||
newDNS,
|
||||
s.mtu)
|
||||
if err != nil {
|
||||
// Restart proxy manager with old tnet on failure
|
||||
s.proxyManager.Start()
|
||||
return fmt.Errorf("failed to create new TUN device: %v", err)
|
||||
}
|
||||
|
||||
// Get current device config before closing
|
||||
currentConfig, err := s.device.IpcGet()
|
||||
if err != nil {
|
||||
newTun.Close()
|
||||
s.proxyManager.Start()
|
||||
return fmt.Errorf("failed to get current device config: %v", err)
|
||||
}
|
||||
|
||||
// Filter out read-only fields from the config
|
||||
filteredConfig := s.filterReadOnlyFields(currentConfig)
|
||||
|
||||
// if onNetstackClose callback is set, call it
|
||||
if s.onNetstackClose != nil {
|
||||
s.onNetstackClose()
|
||||
}
|
||||
|
||||
// Close old device (this closes the old TUN device)
|
||||
s.device.Close()
|
||||
|
||||
// Update references
|
||||
s.tun = newTun
|
||||
s.tnet = newTnet
|
||||
s.dns = newDNS
|
||||
|
||||
// Create new WireGuard device with same port
|
||||
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||
device.LogLevelSilent,
|
||||
"wireguard: ",
|
||||
))
|
||||
|
||||
// Restore the configuration (without read-only fields)
|
||||
err = s.device.IpcSet(filteredConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
|
||||
}
|
||||
|
||||
// Bring up the device
|
||||
err = s.device.Up()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
|
||||
}
|
||||
|
||||
// Update proxy manager with new tnet and restart
|
||||
s.proxyManager.SetTNet(s.tnet)
|
||||
s.proxyManager.Start()
|
||||
|
||||
s.proxyManager.PrintTargets()
|
||||
|
||||
// Call the netstack ready callback if set
|
||||
if s.onNetstackReady != nil {
|
||||
go s.onNetstackReady(s.tnet)
|
||||
}
|
||||
|
||||
logger.Info("Netstack replaced successfully with new DNS servers")
|
||||
return nil
|
||||
}
|
||||
|
||||
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
|
||||
func (s *WireGuardService) filterReadOnlyFields(config string) string {
|
||||
lines := strings.Split(config, "\n")
|
||||
var filteredLines []string
|
||||
|
||||
// List of read-only fields that should not be included in IpcSet
|
||||
readOnlyFields := map[string]bool{
|
||||
"last_handshake_time_sec": true,
|
||||
"last_handshake_time_nsec": true,
|
||||
"rx_bytes": true,
|
||||
"tx_bytes": true,
|
||||
"protocol_version": true,
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if this line contains a read-only field
|
||||
isReadOnly := false
|
||||
for field := range readOnlyFields {
|
||||
if strings.HasPrefix(line, field+"=") {
|
||||
isReadOnly = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Only include non-read-only lines
|
||||
if !isReadOnly {
|
||||
filteredLines = append(filteredLines, line)
|
||||
}
|
||||
}
|
||||
|
||||
return strings.Join(filteredLines, "\n")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user