From 499ebcd928e272bb82edec0acbdcc9b23e872fcf Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 25 Jul 2025 10:59:34 -0700 Subject: [PATCH] Maybe its working? --- linux.go | 11 +++++ wgnetstack/wgnetstack.go | 13 +++++ wgtester/wgtester.go | 102 +++++++++++++++++++++++++++++++++------ 3 files changed, 111 insertions(+), 15 deletions(-) diff --git a/linux.go b/linux.go index a769b5a..2a317f7 100644 --- a/linux.go +++ b/linux.go @@ -13,6 +13,7 @@ import ( // "github.com/fosrl/newt/wg" "github.com/fosrl/newt/wgnetstack" "github.com/fosrl/newt/wgtester" + "golang.zx2c4.com/wireguard/tun/netstack" ) var wgService *wgnetstack.WireGuardService @@ -40,6 +41,16 @@ func setupClients(client *websocket.Client) { logger.Error("Failed to start WireGuard tester server: %v", err) } + // Set up callback to restart wgtester with netstack when WireGuard is ready + wgService.SetOnNetstackReady(func(tnet *netstack.Net) { + logger.Info("WireGuard netstack is ready, restarting wgtester with netstack") + if err := wgTesterServer.RestartWithNetstack(tnet); err != nil { + logger.Error("Failed to restart wgtester with netstack: %v", err) + } else { + logger.Info("WGTester successfully restarted with netstack") + } + }) + client.OnTokenUpdate(func(token string) { wgService.SetToken(token) }) diff --git a/wgnetstack/wgnetstack.go b/wgnetstack/wgnetstack.go index 72da713..533f363 100644 --- a/wgnetstack/wgnetstack.go +++ b/wgnetstack/wgnetstack.go @@ -72,6 +72,8 @@ type WireGuardService struct { tnet *netstack.Net device *device.Device dns []netip.Addr + // Callback for when netstack is ready + onNetstackReady func(*netstack.Net) } // Add this type definition @@ -257,6 +259,11 @@ func (s *WireGuardService) GetPublicKey() wgtypes.Key { return s.key.PublicKey() } +// SetOnNetstackReady sets a callback function to be called when the netstack interface is ready +func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) { + s.onNetstackReady = callback +} + func (s *WireGuardService) LoadRemoteConfig() error { s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ "publicKey": s.key.PublicKey().String(), @@ -348,6 +355,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { } logger.Info("WireGuard netstack device created and configured") + + // Call the callback if it's set to notify that netstack is ready + if s.onNetstackReady != nil { + s.onNetstackReady(s.tnet) + } + return nil } diff --git a/wgtester/wgtester.go b/wgtester/wgtester.go index 4495e77..0035f05 100644 --- a/wgtester/wgtester.go +++ b/wgtester/wgtester.go @@ -8,6 +8,8 @@ import ( "time" "github.com/fosrl/newt/logger" + "golang.zx2c4.com/wireguard/tun/netstack" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" ) const ( @@ -26,7 +28,9 @@ const ( // Server handles listening for connection check requests using UDP type Server struct { - conn *net.UDPConn + conn net.Conn // Generic net.Conn interface (could be *net.UDPConn or *gonet.UDPConn) + udpConn *net.UDPConn // Regular UDP connection (when not using netstack) + netstackConn interface{} // Netstack UDP connection (when using netstack) serverAddr string serverPort uint16 shutdownCh chan struct{} @@ -34,6 +38,8 @@ type Server struct { runningLock sync.Mutex newtID string outputPrefix string + useNetstack bool + tnet interface{} // Will be *netstack.Net when using netstack } // NewServer creates a new connection test server using UDP @@ -44,6 +50,21 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server { shutdownCh: make(chan struct{}), newtID: newtID, outputPrefix: "[WGTester] ", + useNetstack: false, + tnet: nil, + } +} + +// NewServerWithNetstack creates a new connection test server using WireGuard netstack +func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server { + return &Server{ + serverAddr: serverAddr, + serverPort: serverPort + 1, // use the next port for the server + shutdownCh: make(chan struct{}), + newtID: newtID, + outputPrefix: "[WGTester] ", + useNetstack: true, + tnet: tnet, } } @@ -59,18 +80,30 @@ func (s *Server) Start() error { //create the address to listen on addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort)) - // Create UDP address to listen on - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return err - } + if s.useNetstack && s.tnet != nil { + // Use WireGuard netstack + tnet := s.tnet.(*netstack.Net) + udpAddr := &net.UDPAddr{Port: int(s.serverPort)} + netstackConn, err := tnet.ListenUDP(udpAddr) + if err != nil { + return err + } + s.netstackConn = netstackConn + s.conn = netstackConn + } else { + // Use regular UDP socket + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return err + } - // Create UDP connection - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return err + udpConn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return err + } + s.udpConn = udpConn + s.conn = udpConn } - s.conn = conn s.isRunning = true go s.handleConnections() @@ -96,6 +129,26 @@ func (s *Server) Stop() { logger.Info(s.outputPrefix + "Server stopped") } +// RestartWithNetstack stops the current server and restarts it with netstack +func (s *Server) RestartWithNetstack(tnet *netstack.Net) error { + s.Stop() + + // Update configuration to use netstack + s.useNetstack = true + s.tnet = tnet + + // Clear previous connections + s.conn = nil + s.udpConn = nil + s.netstackConn = nil + + // Create new shutdown channel + s.shutdownCh = make(chan struct{}) + + // Restart the server + return s.Start() +} + // handleConnections processes incoming packets func (s *Server) handleConnections() { buffer := make([]byte, 2000) // Buffer large enough for any UDP packet @@ -112,8 +165,18 @@ func (s *Server) handleConnections() { continue } - // Read from UDP connection - n, addr, err := s.conn.ReadFromUDP(buffer) + // Read from UDP connection - handle both regular UDP and netstack UDP + var n int + var addr net.Addr + if s.useNetstack { + // Use netstack UDP connection + netstackConn := s.netstackConn.(*gonet.UDPConn) + n, addr, err = netstackConn.ReadFrom(buffer) + } else { + // Use regular UDP connection + n, addr, err = s.udpConn.ReadFromUDP(buffer) + } + if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { // Just a timeout, keep going @@ -158,8 +221,17 @@ func (s *Server) handleConnections() { // Log response being sent for debugging logger.Debug(s.outputPrefix+"Sending response to %s", addr.String()) - // Send the response packet directly to the source address - _, err = s.conn.WriteToUDP(responsePacket, addr) + // Send the response packet - handle both regular UDP and netstack UDP + if s.useNetstack { + // Use netstack UDP connection + netstackConn := s.netstackConn.(*gonet.UDPConn) + _, err = netstackConn.WriteTo(responsePacket, addr) + } else { + // Use regular UDP connection + udpAddr := addr.(*net.UDPAddr) + _, err = s.udpConn.WriteToUDP(responsePacket, udpAddr) + } + if err != nil { logger.Error(s.outputPrefix+"Error sending response: %v", err) } else {