Working but no wgtester? - revert if bad

This commit is contained in:
Owen
2025-11-29 17:38:34 -05:00
parent 5196effdb8
commit de96be810b
3 changed files with 332 additions and 34 deletions

View File

@@ -16,6 +16,25 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)
// PacketSource identifies where a packet came from
type PacketSource uint8
const (
SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients)
SourceNetstack // From netstack (relay through main tunnel)
)
// SourceAwareEndpoint wraps an endpoint with source information
type SourceAwareEndpoint struct {
wgConn.Endpoint
source PacketSource
}
// GetSource returns the source of this endpoint
func (e *SourceAwareEndpoint) GetSource() PacketSource {
return e.source
}
// injectedPacket represents a packet injected into the SharedBind from an internal source
type injectedPacket struct {
data []byte
@@ -59,10 +78,12 @@ func (e *Endpoint) SrcToString() string {
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
// It also supports receiving packets from a netstack and routing responses
// back through the appropriate source.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
// The underlying UDP connection (for hole-punched clients)
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
@@ -79,8 +100,15 @@ type SharedBind struct {
// Port binding information
port uint16
// Channel for injected packets (from direct relay)
injectedPackets chan injectedPacket
// Channel for packets from netstack (from direct relay)
netstackPackets chan injectedPacket
// Netstack connection for sending responses back through the tunnel
netstackConn net.PacketConn
netstackMu sync.RWMutex
// Track which endpoints came from netstack (key: AddrPort string, value: true)
netstackEndpoints sync.Map
}
// New creates a new SharedBind from an existing UDP connection.
@@ -93,7 +121,7 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
bind := &SharedBind{
udpConn: udpConn,
injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets
netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets
}
// Initialize reference count to 1 (the creator holds the first reference)
@@ -107,6 +135,21 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
return bind, nil
}
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
// This connection is used for relay traffic that should go back through the main tunnel.
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
b.netstackMu.Lock()
defer b.netstackMu.Unlock()
b.netstackConn = conn
}
// GetNetstackConn returns the netstack connection if set
func (b *SharedBind) GetNetstackConn() net.PacketConn {
b.netstackMu.RLock()
defer b.netstackMu.RUnlock()
return b.netstackConn
}
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
// This is used for direct relay from netstack without going through the host network.
// The fromAddr should be the address the packet appears to come from.
@@ -115,19 +158,22 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
return net.ErrClosed
}
// Track this endpoint as coming from netstack so responses go back the same way
b.netstackEndpoints.Store(fromAddr.String(), true)
// Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case b.injectedPackets <- injectedPacket{
case b.netstackPackets <- injectedPacket{
data: dataCopy,
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}:
return nil
default:
// Channel full, drop the packet
return fmt.Errorf("injected packet buffer full")
return fmt.Errorf("netstack packet buffer full")
}
}
@@ -178,9 +224,28 @@ func (b *SharedBind) closeConnection() error {
b.ipv4PC = nil
b.ipv6PC = nil
// Clear netstack connection (but don't close it - it's managed externally)
b.netstackMu.Lock()
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
return err
}
// ClearNetstackConn clears the netstack connection and tracked endpoints.
// Call this when stopping the relay.
func (b *SharedBind) ClearNetstackConn() {
b.netstackMu.Lock()
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
@@ -266,9 +331,9 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 0, net.ErrClosed
}
// Check for injected packets first (non-blocking)
// Check for netstack packets first (non-blocking)
select {
case pkt := <-b.injectedPackets:
case pkt := <-b.netstackPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
@@ -276,7 +341,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 1, nil
}
default:
// No injected packets, continue to check socket
// No netstack packets, continue to check socket
}
b.mu.RLock()
@@ -288,7 +353,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 0, net.ErrClosed
}
// Set a short read deadline so we can poll for injected packets
// Set a short read deadline so we can poll for netstack packets
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
var n int
@@ -302,7 +367,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - loop back to check for injected packets
// Timeout - loop back to check for netstack packets
continue
}
return n, err
@@ -360,26 +425,19 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
// It sends packets to the specified endpoint, routing through the appropriate
// source (netstack or physical socket) based on where the endpoint's packets came from.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
var destAddrPort netip.AddrPort
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
destAddrPort = stdEp.AddrPort
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
@@ -396,15 +454,46 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
destAddrPort = netip.AddrPortFrom(addr, port)
}
}
}
if destAddr == nil {
if !destAddrPort.IsValid() {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Check if this endpoint came from netstack - if so, send through netstack
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint {
b.netstackMu.RLock()
netstackConn := b.netstackConn
b.netstackMu.RUnlock()
if netstackConn != nil {
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers through netstack
for _, buf := range bufs {
_, err := netstackConn.WriteTo(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// Fall through to socket if netstack conn not available
}
// Send through the physical UDP socket (for hole-punched clients)
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)

View File

@@ -422,3 +422,184 @@ func TestParseEndpoint(t *testing.T) {
})
}
}
// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack
func TestNetstackRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Create a mock "netstack" connection (just another UDP socket for testing)
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
// Set the netstack connection
sharedBind.SetNetstackConn(netstackConn)
// Create a "client" that would receive packets
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
defer clientConn.Close()
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
clientAddrPort := clientAddr.AddrPort()
// Inject a packet from the "netstack" source - this should track the endpoint
testData := []byte("test packet from netstack")
err = sharedBind.InjectPacket(testData, clientAddrPort)
if err != nil {
t.Fatalf("InjectPacket failed: %v", err)
}
// Now when we send a response to this endpoint, it should go through netstack
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
responseData := []byte("response packet")
err = sharedBind.Send([][]byte{responseData}, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// The packet should be received by the client from the netstack connection
buf := make([]byte, 1024)
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, fromAddr, err := clientConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive response: %v", err)
}
if string(buf[:n]) != string(responseData) {
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
}
// Verify the response came from the netstack connection, not the physical one
netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr)
if fromAddr.Port != netstackAddr.Port {
t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port)
}
}
// TestSocketRouting tests that packets from socket endpoints are routed through socket
func TestSocketRouting(t *testing.T) {
// Create the SharedBind with a physical UDP socket
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Create a mock "netstack" connection
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
// Set the netstack connection
sharedBind.SetNetstackConn(netstackConn)
// Create a "client" that would receive packets (this simulates a hole-punched client)
clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create client UDP connection: %v", err)
}
defer clientConn.Close()
clientAddr := clientConn.LocalAddr().(*net.UDPAddr)
clientAddrPort := clientAddr.AddrPort()
// Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced
// So Send should use the physical socket
endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort}
responseData := []byte("response packet via socket")
err = sharedBind.Send([][]byte{responseData}, endpoint)
if err != nil {
t.Fatalf("Send failed: %v", err)
}
// The packet should be received by the client from the physical connection
buf := make([]byte, 1024)
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
n, fromAddr, err := clientConn.ReadFromUDP(buf)
if err != nil {
t.Fatalf("Failed to receive response: %v", err)
}
if string(buf[:n]) != string(responseData) {
t.Errorf("Expected to receive %q, got %q", responseData, buf[:n])
}
// Verify the response came from the physical connection, not the netstack one
physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr)
if fromAddr.Port != physicalAddr.Port {
t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port)
}
}
// TestClearNetstackConn tests that clearing the netstack connection works correctly
func TestClearNetstackConn(t *testing.T) {
physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create physical UDP connection: %v", err)
}
sharedBind, err := New(physicalConn)
if err != nil {
t.Fatalf("Failed to create SharedBind: %v", err)
}
defer sharedBind.Close()
// Set a netstack connection
netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
if err != nil {
t.Fatalf("Failed to create netstack UDP connection: %v", err)
}
defer netstackConn.Close()
sharedBind.SetNetstackConn(netstackConn)
// Inject a packet to track an endpoint
testAddrPort := netip.MustParseAddrPort("192.168.1.100:51820")
err = sharedBind.InjectPacket([]byte("test"), testAddrPort)
if err != nil {
t.Fatalf("InjectPacket failed: %v", err)
}
// Verify the endpoint is tracked
_, tracked := sharedBind.netstackEndpoints.Load(testAddrPort.String())
if !tracked {
t.Error("Expected endpoint to be tracked as netstack-sourced")
}
// Clear the netstack connection
sharedBind.ClearNetstackConn()
// Verify the netstack connection is cleared
if sharedBind.GetNetstackConn() != nil {
t.Error("Expected netstack connection to be nil after clear")
}
// Verify the tracked endpoints are cleared
_, stillTracked := sharedBind.netstackEndpoints.Load(testAddrPort.String())
if stillTracked {
t.Error("Expected endpoint tracking to be cleared")
}
}

View File

@@ -99,8 +99,10 @@ type WireGuardService struct {
holePunchManager *holepunch.Manager
useNativeInterface bool
// Direct UDP relay from main tunnel to clients' WireGuard
directRelayStop chan struct{}
directRelayWg sync.WaitGroup
directRelayStop chan struct{}
directRelayWg sync.WaitGroup
netstackListener net.PacketConn
netstackListenerMu sync.Mutex
}
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
@@ -300,6 +302,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard.
// This bypasses the proxy by listening on the main tunnel's netstack and forwarding packets
// directly to the SharedBind that feeds the clients' WireGuard device.
// Responses are automatically routed back through the netstack by the SharedBind.
// tunnelIP is the IP address to listen on within the main tunnel's netstack.
func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error {
if s.othertnet == nil {
@@ -332,21 +335,33 @@ func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error {
return fmt.Errorf("failed to listen on main tunnel netstack: %v", err)
}
logger.Info("Started direct UDP relay on %s:%d (bypassing proxy)", tunnelIP, s.Port)
// Store the listener reference so we can close it later
s.netstackListenerMu.Lock()
s.netstackListener = listener
s.netstackListenerMu.Unlock()
// Start the relay goroutine
// Set the netstack connection on the SharedBind so responses go back through the tunnel
s.sharedBind.SetNetstackConn(listener)
logger.Info("Started direct UDP relay on %s:%d (bidirectional via SharedBind)", tunnelIP, s.Port)
// Start the relay goroutine to read from netstack and inject into SharedBind
s.directRelayWg.Add(1)
go s.runDirectUDPRelay(listener)
return nil
}
// runDirectUDPRelay handles the UDP relay between the main tunnel netstack and the SharedBind
// runDirectUDPRelay handles receiving UDP packets from the main tunnel netstack
// and injecting them into the SharedBind for processing by WireGuard.
// Responses are handled automatically by SharedBind.Send() which routes them
// back through the netstack connection.
func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {
defer s.directRelayWg.Done()
defer listener.Close()
// Note: Don't close listener here - it's also used by SharedBind for sending responses
// It will be closed when the relay is stopped
logger.Info("Direct UDP relay started (injecting directly into SharedBind)")
logger.Info("Direct UDP relay started (bidirectional through SharedBind)")
buf := make([]byte, 65535) // Max UDP packet size
@@ -386,23 +401,36 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {
continue
}
// Inject the packet directly into the SharedBind
// Inject the packet directly into the SharedBind (also tracks this endpoint as netstack-sourced)
if err := s.sharedBind.InjectPacket(buf[:n], srcAddrPort); err != nil {
logger.Debug("Failed to inject packet into SharedBind: %v", err)
continue
}
logger.Debug("Injected %d bytes from %s into SharedBind", n, srcAddrPort.String())
logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String())
}
}
// StopDirectUDPRelay stops the direct UDP relay
// StopDirectUDPRelay stops the direct UDP relay and closes the netstack listener
func (s *WireGuardService) StopDirectUDPRelay() {
if s.directRelayStop != nil {
close(s.directRelayStop)
s.directRelayWg.Wait()
s.directRelayStop = nil
}
// Clear the netstack connection from SharedBind so responses don't try to use it
if s.sharedBind != nil {
s.sharedBind.ClearNetstackConn()
}
// Close the netstack listener
s.netstackListenerMu.Lock()
if s.netstackListener != nil {
s.netstackListener.Close()
s.netstackListener = nil
}
s.netstackListenerMu.Unlock()
}
func (s *WireGuardService) LoadRemoteConfig() error {