Maybe its working?

This commit is contained in:
Owen
2025-07-25 10:59:34 -07:00
parent 40dfab31a5
commit 499ebcd928
3 changed files with 111 additions and 15 deletions

View File

@@ -13,6 +13,7 @@ import (
// "github.com/fosrl/newt/wg"
"github.com/fosrl/newt/wgnetstack"
"github.com/fosrl/newt/wgtester"
"golang.zx2c4.com/wireguard/tun/netstack"
)
var wgService *wgnetstack.WireGuardService
@@ -40,6 +41,16 @@ func setupClients(client *websocket.Client) {
logger.Error("Failed to start WireGuard tester server: %v", err)
}
// Set up callback to restart wgtester with netstack when WireGuard is ready
wgService.SetOnNetstackReady(func(tnet *netstack.Net) {
logger.Info("WireGuard netstack is ready, restarting wgtester with netstack")
if err := wgTesterServer.RestartWithNetstack(tnet); err != nil {
logger.Error("Failed to restart wgtester with netstack: %v", err)
} else {
logger.Info("WGTester successfully restarted with netstack")
}
})
client.OnTokenUpdate(func(token string) {
wgService.SetToken(token)
})

View File

@@ -72,6 +72,8 @@ type WireGuardService struct {
tnet *netstack.Net
device *device.Device
dns []netip.Addr
// Callback for when netstack is ready
onNetstackReady func(*netstack.Net)
}
// Add this type definition
@@ -257,6 +259,11 @@ func (s *WireGuardService) GetPublicKey() wgtypes.Key {
return s.key.PublicKey()
}
// SetOnNetstackReady sets a callback function to be called when the netstack interface is ready
func (s *WireGuardService) SetOnNetstackReady(callback func(*netstack.Net)) {
s.onNetstackReady = callback
}
func (s *WireGuardService) LoadRemoteConfig() error {
s.stopGetConfig = s.client.SendMessageInterval("newt/wg/get-config", map[string]interface{}{
"publicKey": s.key.PublicKey().String(),
@@ -348,6 +355,12 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
}
logger.Info("WireGuard netstack device created and configured")
// Call the callback if it's set to notify that netstack is ready
if s.onNetstackReady != nil {
s.onNetstackReady(s.tnet)
}
return nil
}

View File

@@ -8,6 +8,8 @@ import (
"time"
"github.com/fosrl/newt/logger"
"golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
)
const (
@@ -26,7 +28,9 @@ const (
// Server handles listening for connection check requests using UDP
type Server struct {
conn *net.UDPConn
conn net.Conn // Generic net.Conn interface (could be *net.UDPConn or *gonet.UDPConn)
udpConn *net.UDPConn // Regular UDP connection (when not using netstack)
netstackConn interface{} // Netstack UDP connection (when using netstack)
serverAddr string
serverPort uint16
shutdownCh chan struct{}
@@ -34,6 +38,8 @@ type Server struct {
runningLock sync.Mutex
newtID string
outputPrefix string
useNetstack bool
tnet interface{} // Will be *netstack.Net when using netstack
}
// NewServer creates a new connection test server using UDP
@@ -44,6 +50,21 @@ func NewServer(serverAddr string, serverPort uint16, newtID string) *Server {
shutdownCh: make(chan struct{}),
newtID: newtID,
outputPrefix: "[WGTester] ",
useNetstack: false,
tnet: nil,
}
}
// NewServerWithNetstack creates a new connection test server using WireGuard netstack
func NewServerWithNetstack(serverAddr string, serverPort uint16, newtID string, tnet *netstack.Net) *Server {
return &Server{
serverAddr: serverAddr,
serverPort: serverPort + 1, // use the next port for the server
shutdownCh: make(chan struct{}),
newtID: newtID,
outputPrefix: "[WGTester] ",
useNetstack: true,
tnet: tnet,
}
}
@@ -59,18 +80,30 @@ func (s *Server) Start() error {
//create the address to listen on
addr := net.JoinHostPort(s.serverAddr, fmt.Sprintf("%d", s.serverPort))
// Create UDP address to listen on
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
if s.useNetstack && s.tnet != nil {
// Use WireGuard netstack
tnet := s.tnet.(*netstack.Net)
udpAddr := &net.UDPAddr{Port: int(s.serverPort)}
netstackConn, err := tnet.ListenUDP(udpAddr)
if err != nil {
return err
}
s.netstackConn = netstackConn
s.conn = netstackConn
} else {
// Use regular UDP socket
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return err
}
// Create UDP connection
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
udpConn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return err
}
s.udpConn = udpConn
s.conn = udpConn
}
s.conn = conn
s.isRunning = true
go s.handleConnections()
@@ -96,6 +129,26 @@ func (s *Server) Stop() {
logger.Info(s.outputPrefix + "Server stopped")
}
// RestartWithNetstack stops the current server and restarts it with netstack
func (s *Server) RestartWithNetstack(tnet *netstack.Net) error {
s.Stop()
// Update configuration to use netstack
s.useNetstack = true
s.tnet = tnet
// Clear previous connections
s.conn = nil
s.udpConn = nil
s.netstackConn = nil
// Create new shutdown channel
s.shutdownCh = make(chan struct{})
// Restart the server
return s.Start()
}
// handleConnections processes incoming packets
func (s *Server) handleConnections() {
buffer := make([]byte, 2000) // Buffer large enough for any UDP packet
@@ -112,8 +165,18 @@ func (s *Server) handleConnections() {
continue
}
// Read from UDP connection
n, addr, err := s.conn.ReadFromUDP(buffer)
// Read from UDP connection - handle both regular UDP and netstack UDP
var n int
var addr net.Addr
if s.useNetstack {
// Use netstack UDP connection
netstackConn := s.netstackConn.(*gonet.UDPConn)
n, addr, err = netstackConn.ReadFrom(buffer)
} else {
// Use regular UDP connection
n, addr, err = s.udpConn.ReadFromUDP(buffer)
}
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Just a timeout, keep going
@@ -158,8 +221,17 @@ func (s *Server) handleConnections() {
// Log response being sent for debugging
logger.Debug(s.outputPrefix+"Sending response to %s", addr.String())
// Send the response packet directly to the source address
_, err = s.conn.WriteToUDP(responsePacket, addr)
// Send the response packet - handle both regular UDP and netstack UDP
if s.useNetstack {
// Use netstack UDP connection
netstackConn := s.netstackConn.(*gonet.UDPConn)
_, err = netstackConn.WriteTo(responsePacket, addr)
} else {
// Use regular UDP connection
udpAddr := addr.(*net.UDPAddr)
_, err = s.udpConn.WriteToUDP(responsePacket, udpAddr)
}
if err != nil {
logger.Error(s.outputPrefix+"Error sending response: %v", err)
} else {