Maybe its working?

This commit is contained in:
Owen
2025-07-25 10:59:34 -07:00
parent 40dfab31a5
commit 499ebcd928
3 changed files with 111 additions and 15 deletions

View File

@@ -13,6 +13,7 @@ import (
// "github.com/fosrl/newt/wg" // "github.com/fosrl/newt/wg"
"github.com/fosrl/newt/wgnetstack" "github.com/fosrl/newt/wgnetstack"
"github.com/fosrl/newt/wgtester" "github.com/fosrl/newt/wgtester"
"golang.zx2c4.com/wireguard/tun/netstack"
) )
var wgService *wgnetstack.WireGuardService var wgService *wgnetstack.WireGuardService
@@ -40,6 +41,16 @@ func setupClients(client *websocket.Client) {
logger.Error("Failed to start WireGuard tester server: %v", err) logger.Error("Failed to start WireGuard tester server: %v", err)
} }
// Set up callback to restart wgtester with netstack when WireGuard is ready
wgService.SetOnNetstackReady(func(tnet *netstack.Net) {
logger.Info("WireGuard netstack is ready, restarting wgtester with netstack")
if err := wgTesterServer.RestartWithNetstack(tnet); err != nil {
logger.Error("Failed to restart wgtester with netstack: %v", err)
} else {
logger.Info("WGTester successfully restarted with netstack")
}
})
client.OnTokenUpdate(func(token string) { client.OnTokenUpdate(func(token string) {
wgService.SetToken(token) wgService.SetToken(token)
}) })

View File

@@ -72,6 +72,8 @@ type WireGuardService struct {
tnet *netstack.Net tnet *netstack.Net
device *device.Device device *device.Device
dns []netip.Addr dns []netip.Addr
// Callback for when netstack is ready
onNetstackReady func(*netstack.Net)
} }
// Add this type definition // Add this type definition
@@ -257,6 +259,11 @@ func (s *WireGuardService) GetPublicKey() wgtypes.Key {
return s.key.PublicKey() return s.key.PublicKey()
} }
// SetOnNetstackReady sets a callback function to be called when the netstack interface is ready
func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
s.onNetstackReady = callback
}
func (s *WireGuardService) LoadRemoteConfig() error { func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{ s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(), "publicKey": s.key.PublicKey().String(),
@@ -348,6 +355,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
} }
logger.Info("WireGuard netstack device created and configured") logger.Info("WireGuard netstack device created and configured")
// Call the callback if it's set to notify that netstack is ready
if s.onNetstackReady != nil {
s.onNetstackReady(s.tnet)
}
return nil return nil
} }

View File

@@ -8,6 +8,8 @@ import (
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
) )
const ( const (
@@ -26,7 +28,9 @@ const (
// Server handles listening for connection check requests using UDP // Server handles listening for connection check requests using UDP
type Server struct { type Server struct {
conn *net.UDPConn conn net.Conn // Generic net.Conn interface (could be *net.UDPConn or *gonet.UDPConn)
udpConn *net.UDPConn // Regular UDP connection (when not using netstack)
netstackConn interface{} // Netstack UDP connection (when using netstack)
serverAddr string serverAddr string
serverPort uint16 serverPort uint16
shutdownCh chan struct{} shutdownCh chan struct{}
@@ -34,6 +38,8 @@ type Server struct {
runningLock sync.Mutex runningLock sync.Mutex
newtID string newtID string
outputPrefix string outputPrefix string
useNetstack bool
tnet interface{} // Will be *netstack.Net when using netstack
} }
// NewServer creates a new connection test server using UDP // NewServer creates a new connection test server using UDP
@@ -44,6 +50,21 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
shutdownCh: make(chan struct{}), shutdownCh: make(chan struct{}),
newtID: newtID, newtID: newtID,
outputPrefix: "[WGTester] ", outputPrefix: "[WGTester] ",
useNetstack: false,
tnet: nil,
}
}
// NewServerWithNetstack creates a new connection test server using WireGuard netstack
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server {
return &Server{
serverAddr: serverAddr,
serverPort: serverPort + 1, // use the next port for the server
shutdownCh: make(chan struct{}),
newtID: newtID,
outputPrefix: "[WGTester] ",
useNetstack: true,
tnet: tnet,
} }
} }
@@ -59,18 +80,30 @@ func (s *Server) Start() error {
//create the address to listen on //create the address to listen on
addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort)) addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))
// Create UDP address to listen on if s.useNetstack && s.tnet != nil {
udpAddr, err := net.ResolveUDPAddr("udp", addr) // Use WireGuard netstack
if err != nil { tnet := s.tnet.(*netstack.Net)
return err udpAddr := &net.UDPAddr{Port: int(s.serverPort)}
} netstackConn, err := tnet.ListenUDP(udpAddr)
if err != nil {
return err
}
s.netstackConn = netstackConn
s.conn = netstackConn
} else {
// Use regular UDP socket
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
// Create UDP connection udpConn, err := net.ListenUDP("udp", udpAddr)
conn, err := net.ListenUDP("udp", udpAddr) if err != nil {
if err != nil { return err
return err }
s.udpConn = udpConn
s.conn = udpConn
} }
s.conn = conn
s.isRunning = true s.isRunning = true
go s.handleConnections() go s.handleConnections()
@@ -96,6 +129,26 @@ func (s *Server) Stop() {
logger.Info(s.outputPrefix + "Server stopped") logger.Info(s.outputPrefix + "Server stopped")
} }
// RestartWithNetstack stops the current server and restarts it with netstack
func (s *Server) RestartWithNetstack(tnet *netstack.Net) error {
s.Stop()
// Update configuration to use netstack
s.useNetstack = true
s.tnet = tnet
// Clear previous connections
s.conn = nil
s.udpConn = nil
s.netstackConn = nil
// Create new shutdown channel
s.shutdownCh = make(chan struct{})
// Restart the server
return s.Start()
}
// handleConnections processes incoming packets // handleConnections processes incoming packets
func (s *Server) handleConnections() { func (s *Server) handleConnections() {
buffer := make([]byte, 2000) // Buffer large enough for any UDP packet buffer := make([]byte, 2000) // Buffer large enough for any UDP packet
@@ -112,8 +165,18 @@ func (s *Server) handleConnections() {
continue continue
} }
// Read from UDP connection // Read from UDP connection - handle both regular UDP and netstack UDP
n, addr, err := s.conn.ReadFromUDP(buffer) var n int
var addr net.Addr
if s.useNetstack {
// Use netstack UDP connection
netstackConn := s.netstackConn.(*gonet.UDPConn)
n, addr, err = netstackConn.ReadFrom(buffer)
} else {
// Use regular UDP connection
n, addr, err = s.udpConn.ReadFromUDP(buffer)
}
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Just a timeout, keep going // Just a timeout, keep going
@@ -158,8 +221,17 @@ func (s *Server) handleConnections() {
// Log response being sent for debugging // Log response being sent for debugging
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String()) logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
// Send the response packet directly to the source address // Send the response packet - handle both regular UDP and netstack UDP
_, err = s.conn.WriteToUDP(responsePacket, addr) if s.useNetstack {
// Use netstack UDP connection
netstackConn := s.netstackConn.(*gonet.UDPConn)
_, err = netstackConn.WriteTo(responsePacket, addr)
} else {
// Use regular UDP connection
udpAddr := addr.(*net.UDPAddr)
_, err = s.udpConn.WriteToUDP(responsePacket, udpAddr)
}
if err != nil { if err != nil {
logger.Error(s.outputPrefix+"Error sending response: %v", err) logger.Error(s.outputPrefix+"Error sending response: %v", err)
} else { } else {