mirror of
https://github.com/fosrl/newt.git
synced 2026-03-10 12:46:38 +00:00
Fix the bind problem by just recreating the dev
TODO: WHY CANT WE REBIND TO A PORT - WE NEED TO FIX THIS BETTER
This commit is contained in:
@@ -54,6 +54,13 @@ func setupClientsNetstack(client *websocket.Client, host string) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
wgService.SetOnNetstackClose(func() {
|
||||||
|
if wgTesterServer != nil {
|
||||||
|
wgTesterServer.Stop()
|
||||||
|
wgTesterServer = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
client.OnTokenUpdate(func(token string) {
|
client.OnTokenUpdate(func(token string) {
|
||||||
wgService.SetToken(token)
|
wgService.SetToken(token)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -191,13 +191,13 @@ func (pm *ProxyManager) Stop() error {
|
|||||||
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clear the target maps
|
// // Clear the target maps
|
||||||
for k := range pm.tcpTargets {
|
// for k := range pm.tcpTargets {
|
||||||
delete(pm.tcpTargets, k)
|
// delete(pm.tcpTargets, k)
|
||||||
}
|
// }
|
||||||
for k := range pm.udpTargets {
|
// for k := range pm.udpTargets {
|
||||||
delete(pm.udpTargets, k)
|
// delete(pm.udpTargets, k)
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Give active connections a chance to close gracefully
|
// Give active connections a chance to close gracefully
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -368,3 +368,23 @@ func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// write a function to print out the current targets in the ProxyManager
|
||||||
|
func (pm *ProxyManager) PrintTargets() {
|
||||||
|
pm.mutex.RLock()
|
||||||
|
defer pm.mutex.RUnlock()
|
||||||
|
|
||||||
|
logger.Info("Current TCP Targets:")
|
||||||
|
for listenIP, targets := range pm.tcpTargets {
|
||||||
|
for port, targetAddr := range targets {
|
||||||
|
logger.Info("TCP %s:%d -> %s", listenIP, port, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Current UDP Targets:")
|
||||||
|
for listenIP, targets := range pm.udpTargets {
|
||||||
|
for port, targetAddr := range targets {
|
||||||
|
logger.Info("UDP %s:%d -> %s", listenIP, port, targetAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -85,6 +85,8 @@ type WireGuardService struct {
|
|||||||
dns []netip.Addr
|
dns []netip.Addr
|
||||||
// Callback for when netstack is ready
|
// Callback for when netstack is ready
|
||||||
onNetstackReady func(*netstack.Net)
|
onNetstackReady func(*netstack.Net)
|
||||||
|
// Callback for when netstack is closed
|
||||||
|
onNetstackClose func()
|
||||||
othertnet *netstack.Net
|
othertnet *netstack.Net
|
||||||
// Proxy manager for tunnel
|
// Proxy manager for tunnel
|
||||||
proxyManager *proxy.ProxyManager
|
proxyManager *proxy.ProxyManager
|
||||||
@@ -254,7 +256,7 @@ func (s *WireGuardService) addTcpTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
|
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", targetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -274,7 +276,7 @@ func (s *WireGuardService) addUdpTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
|
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", targetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,7 +296,7 @@ func (s *WireGuardService) removeUdpTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
|
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "udp", targetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -314,7 +316,7 @@ func (s *WireGuardService) removeTcpTarget(msg websocket.WSMessage) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(targetData.Targets) > 0 {
|
if len(targetData.Targets) > 0 {
|
||||||
updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
|
s.updateTargets(s.proxyManager, "remove", s.TunnelIP, "tcp", targetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -392,6 +394,10 @@ func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
|
|||||||
s.onNetstackReady = callback
|
s.onNetstackReady = callback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *WireGuardService) SetOnNetstackClose(callback func()) {
|
||||||
|
s.onNetstackClose = callback
|
||||||
|
}
|
||||||
|
|
||||||
func (s *WireGuardService) LoadRemoteConfig() error {
|
func (s *WireGuardService) LoadRemoteConfig() error {
|
||||||
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
|
||||||
"publicKey": s.key.PublicKey().String(),
|
"publicKey": s.key.PublicKey().String(),
|
||||||
@@ -438,11 +444,11 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) {
|
|||||||
|
|
||||||
// add the targets if there are any
|
// add the targets if there are any
|
||||||
if len(config.Targets.TCP) > 0 {
|
if len(config.Targets.TCP) > 0 {
|
||||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
|
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "tcp", TargetData{Targets: config.Targets.TCP})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(config.Targets.UDP) > 0 {
|
if len(config.Targets.UDP) > 0 {
|
||||||
updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
|
s.updateTargets(s.proxyManager, "add", s.TunnelIP, "udp", TargetData{Targets: config.Targets.UDP})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create ProxyManager for this tunnel
|
// Create ProxyManager for this tunnel
|
||||||
@@ -1077,7 +1083,8 @@ func (s *WireGuardService) keepSendingUDPHolePunch(host string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
func (s *WireGuardService) updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto string, targetData TargetData) error {
|
||||||
|
var replace = true
|
||||||
for _, t := range targetData.Targets {
|
for _, t := range targetData.Targets {
|
||||||
// Split the first number off of the target with : separator and use as the port
|
// Split the first number off of the target with : separator and use as the port
|
||||||
parts := strings.Split(t, ":")
|
parts := strings.Split(t, ":")
|
||||||
@@ -1106,6 +1113,8 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
|||||||
// Ignore "target not found" errors as this is expected for new targets
|
// Ignore "target not found" errors as this is expected for new targets
|
||||||
if !strings.Contains(err.Error(), "target not found") {
|
if !strings.Contains(err.Error(), "target not found") {
|
||||||
logger.Error("Failed to remove existing target: %v", err)
|
logger.Error("Failed to remove existing target: %v", err)
|
||||||
|
} else {
|
||||||
|
replace = false // If we got here, it means the target didn't exist, so we can add it without replacing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1123,6 +1132,17 @@ func updateTargets(pm *proxy.ProxyManager, action string, tunnelIP string, proto
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if replace {
|
||||||
|
// If we replaced any targets, we need to hot swap the netstack
|
||||||
|
if err := s.ReplaceNetstack(s.dns); err != nil {
|
||||||
|
logger.Error("Failed to replace netstack after updating targets: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
logger.Info("Netstack replaced successfully after updating targets")
|
||||||
|
} else {
|
||||||
|
logger.Info("No targets updated, no netstack replacement needed")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1140,3 +1160,127 @@ func parseTargetData(data interface{}) (TargetData, error) {
|
|||||||
}
|
}
|
||||||
return targetData, nil
|
return targetData, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add this method to WireGuardService
|
||||||
|
func (s *WireGuardService) ReplaceNetstack(newDNS []netip.Addr) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.device == nil || s.tun == nil {
|
||||||
|
return fmt.Errorf("WireGuard device not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the current tunnel IP from the existing config
|
||||||
|
parts := strings.Split(s.config.IpAddress, "/")
|
||||||
|
if len(parts) != 2 {
|
||||||
|
return fmt.Errorf("invalid IP address format: %s", s.config.IpAddress)
|
||||||
|
}
|
||||||
|
tunnelIP := netip.MustParseAddr(parts[0])
|
||||||
|
|
||||||
|
// Stop the proxy manager temporarily
|
||||||
|
s.proxyManager.Stop()
|
||||||
|
|
||||||
|
// Create new TUN device and netstack with new DNS
|
||||||
|
newTun, newTnet, err := netstack.CreateNetTUN(
|
||||||
|
[]netip.Addr{tunnelIP},
|
||||||
|
newDNS,
|
||||||
|
s.mtu)
|
||||||
|
if err != nil {
|
||||||
|
// Restart proxy manager with old tnet on failure
|
||||||
|
s.proxyManager.Start()
|
||||||
|
return fmt.Errorf("failed to create new TUN device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get current device config before closing
|
||||||
|
currentConfig, err := s.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
newTun.Close()
|
||||||
|
s.proxyManager.Start()
|
||||||
|
return fmt.Errorf("failed to get current device config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out read-only fields from the config
|
||||||
|
filteredConfig := s.filterReadOnlyFields(currentConfig)
|
||||||
|
|
||||||
|
// if onNetstackClose callback is set, call it
|
||||||
|
if s.onNetstackClose != nil {
|
||||||
|
s.onNetstackClose()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close old device (this closes the old TUN device)
|
||||||
|
s.device.Close()
|
||||||
|
|
||||||
|
// Update references
|
||||||
|
s.tun = newTun
|
||||||
|
s.tnet = newTnet
|
||||||
|
s.dns = newDNS
|
||||||
|
|
||||||
|
// Create new WireGuard device with same port
|
||||||
|
s.device = device.NewDevice(s.tun, NewFixedPortBind(s.Port), device.NewLogger(
|
||||||
|
device.LogLevelSilent,
|
||||||
|
"wireguard: ",
|
||||||
|
))
|
||||||
|
|
||||||
|
// Restore the configuration (without read-only fields)
|
||||||
|
err = s.device.IpcSet(filteredConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to restore WireGuard configuration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bring up the device
|
||||||
|
err = s.device.Up()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to bring up new WireGuard device: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update proxy manager with new tnet and restart
|
||||||
|
s.proxyManager.SetTNet(s.tnet)
|
||||||
|
s.proxyManager.Start()
|
||||||
|
|
||||||
|
s.proxyManager.PrintTargets()
|
||||||
|
|
||||||
|
// Call the netstack ready callback if set
|
||||||
|
if s.onNetstackReady != nil {
|
||||||
|
go s.onNetstackReady(s.tnet)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Info("Netstack replaced successfully with new DNS servers")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterReadOnlyFields removes read-only fields from WireGuard IPC configuration
|
||||||
|
func (s *WireGuardService) filterReadOnlyFields(config string) string {
|
||||||
|
lines := strings.Split(config, "\n")
|
||||||
|
var filteredLines []string
|
||||||
|
|
||||||
|
// List of read-only fields that should not be included in IpcSet
|
||||||
|
readOnlyFields := map[string]bool{
|
||||||
|
"last_handshake_time_sec": true,
|
||||||
|
"last_handshake_time_nsec": true,
|
||||||
|
"rx_bytes": true,
|
||||||
|
"tx_bytes": true,
|
||||||
|
"protocol_version": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, line := range lines {
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if this line contains a read-only field
|
||||||
|
isReadOnly := false
|
||||||
|
for field := range readOnlyFields {
|
||||||
|
if strings.HasPrefix(line, field+"=") {
|
||||||
|
isReadOnly = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only include non-read-only lines
|
||||||
|
if !isReadOnly {
|
||||||
|
filteredLines = append(filteredLines, line)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(filteredLines, "\n")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user