diff --git a/bind/shared_bind.go b/bind/shared_bind.go index 4a0e68d..d6d967c 100644 --- a/bind/shared_bind.go +++ b/bind/shared_bind.go @@ -16,6 +16,25 @@ import ( wgConn "golang.zx2c4.com/wireguard/conn" ) +// PacketSource identifies where a packet came from +type PacketSource uint8 + +const ( + SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients) + SourceNetstack // From netstack (relay through main tunnel) +) + +// SourceAwareEndpoint wraps an endpoint with source information +type SourceAwareEndpoint struct { + wgConn.Endpoint + source PacketSource +} + +// GetSource returns the source of this endpoint +func (e *SourceAwareEndpoint) GetSource() PacketSource { + return e.source +} + // injectedPacket represents a packet injected into the SharedBind from an internal source type injectedPacket struct { data []byte @@ -59,10 +78,12 @@ func (e *Endpoint) SrcToString() string { // SharedBind is a thread-safe UDP bind that can be shared between WireGuard // and hole punch senders. It wraps a single UDP connection and implements // reference counting to prevent premature closure. +// It also supports receiving packets from a netstack and routing responses +// back through the appropriate source. type SharedBind struct { mu sync.RWMutex - // The underlying UDP connection + // The underlying UDP connection (for hole-punched clients) udpConn *net.UDPConn // IPv4 and IPv6 packet connections for advanced features @@ -79,8 +100,15 @@ type SharedBind struct { // Port binding information port uint16 - // Channel for injected packets (from direct relay) - injectedPackets chan injectedPacket + // Channel for packets from netstack (from direct relay) + netstackPackets chan injectedPacket + + // Netstack connection for sending responses back through the tunnel + netstackConn net.PacketConn + netstackMu sync.RWMutex + + // Track which endpoints came from netstack (key: AddrPort string, value: true) + netstackEndpoints sync.Map } // New creates a new SharedBind from an existing UDP connection. @@ -93,7 +121,7 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { bind := &SharedBind{ udpConn: udpConn, - injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets + netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets } // Initialize reference count to 1 (the creator holds the first reference) @@ -107,6 +135,21 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { return bind, nil } +// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel. +// This connection is used for relay traffic that should go back through the main tunnel. +func (b *SharedBind) SetNetstackConn(conn net.PacketConn) { + b.netstackMu.Lock() + defer b.netstackMu.Unlock() + b.netstackConn = conn +} + +// GetNetstackConn returns the netstack connection if set +func (b *SharedBind) GetNetstackConn() net.PacketConn { + b.netstackMu.RLock() + defer b.netstackMu.RUnlock() + return b.netstackConn +} + // InjectPacket allows injecting a packet directly into the SharedBind's receive path. // This is used for direct relay from netstack without going through the host network. // The fromAddr should be the address the packet appears to come from. @@ -115,19 +158,22 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error { return net.ErrClosed } + // Track this endpoint as coming from netstack so responses go back the same way + b.netstackEndpoints.Store(fromAddr.String(), true) + // Make a copy of the data to avoid issues with buffer reuse dataCopy := make([]byte, len(data)) copy(dataCopy, data) select { - case b.injectedPackets <- injectedPacket{ + case b.netstackPackets <- injectedPacket{ data: dataCopy, endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr}, }: return nil default: // Channel full, drop the packet - return fmt.Errorf("injected packet buffer full") + return fmt.Errorf("netstack packet buffer full") } } @@ -178,9 +224,28 @@ func (b *SharedBind) closeConnection() error { b.ipv4PC = nil b.ipv6PC = nil + // Clear netstack connection (but don't close it - it's managed externally) + b.netstackMu.Lock() + b.netstackConn = nil + b.netstackMu.Unlock() + + // Clear tracked netstack endpoints + b.netstackEndpoints = sync.Map{} + return err } +// ClearNetstackConn clears the netstack connection and tracked endpoints. +// Call this when stopping the relay. +func (b *SharedBind) ClearNetstackConn() { + b.netstackMu.Lock() + b.netstackConn = nil + b.netstackMu.Unlock() + + // Clear tracked netstack endpoints + b.netstackEndpoints = sync.Map{} +} + // GetUDPConn returns the underlying UDP connection. // The caller must not close this connection directly. func (b *SharedBind) GetUDPConn() *net.UDPConn { @@ -266,9 +331,9 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { return 0, net.ErrClosed } - // Check for injected packets first (non-blocking) + // Check for netstack packets first (non-blocking) select { - case pkt := <-b.injectedPackets: + case pkt := <-b.netstackPackets: if len(pkt.data) <= len(bufs[0]) { copy(bufs[0], pkt.data) sizes[0] = len(pkt.data) @@ -276,7 +341,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { return 1, nil } default: - // No injected packets, continue to check socket + // No netstack packets, continue to check socket } b.mu.RLock() @@ -288,7 +353,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { return 0, net.ErrClosed } - // Set a short read deadline so we can poll for injected packets + // Set a short read deadline so we can poll for netstack packets conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) var n int @@ -302,7 +367,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - // Timeout - loop back to check for injected packets + // Timeout - loop back to check for netstack packets continue } return n, err @@ -360,26 +425,19 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [ } // Send implements the WireGuard Bind interface. -// It sends packets to the specified endpoint. +// It sends packets to the specified endpoint, routing through the appropriate +// source (netstack or physical socket) based on where the endpoint's packets came from. func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { if b.closed.Load() { return net.ErrClosed } - b.mu.RLock() - conn := b.udpConn - b.mu.RUnlock() - - if conn == nil { - return net.ErrClosed - } - // Extract the destination address from the endpoint - var destAddr *net.UDPAddr + var destAddrPort netip.AddrPort // Try to cast to StdNetEndpoint first if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { - destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort) + destAddrPort = stdEp.AddrPort } else { // Fallback: construct from DstIP and DstToBytes dstBytes := ep.DstToBytes() @@ -396,15 +454,46 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { } if addr.IsValid() { - destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port)) + destAddrPort = netip.AddrPortFrom(addr, port) } } } - if destAddr == nil { + if !destAddrPort.IsValid() { return fmt.Errorf("could not extract destination address from endpoint") } + // Check if this endpoint came from netstack - if so, send through netstack + if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint { + b.netstackMu.RLock() + netstackConn := b.netstackConn + b.netstackMu.RUnlock() + + if netstackConn != nil { + destAddr := net.UDPAddrFromAddrPort(destAddrPort) + // Send all buffers through netstack + for _, buf := range bufs { + _, err := netstackConn.WriteTo(buf, destAddr) + if err != nil { + return err + } + } + return nil + } + // Fall through to socket if netstack conn not available + } + + // Send through the physical UDP socket (for hole-punched clients) + b.mu.RLock() + conn := b.udpConn + b.mu.RUnlock() + + if conn == nil { + return net.ErrClosed + } + + destAddr := net.UDPAddrFromAddrPort(destAddrPort) + // Send all buffers to the destination for _, buf := range bufs { _, err := conn.WriteToUDP(buf, destAddr) diff --git a/bind/shared_bind_test.go b/bind/shared_bind_test.go index 6e1ec66..0d63e7a 100644 --- a/bind/shared_bind_test.go +++ b/bind/shared_bind_test.go @@ -422,3 +422,184 @@ func TestParseEndpoint(t *testing.T) { }) } } + +// TestNetstackRouting tests that packets from netstack endpoints are routed back through netstack +func TestNetstackRouting(t *testing.T) { + // Create the SharedBind with a physical UDP socket + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Create a mock "netstack" connection (just another UDP socket for testing) + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + // Set the netstack connection + sharedBind.SetNetstackConn(netstackConn) + + // Create a "client" that would receive packets + clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create client UDP connection: %v", err) + } + defer clientConn.Close() + + clientAddr := clientConn.LocalAddr().(*net.UDPAddr) + clientAddrPort := clientAddr.AddrPort() + + // Inject a packet from the "netstack" source - this should track the endpoint + testData := []byte("test packet from netstack") + err = sharedBind.InjectPacket(testData, clientAddrPort) + if err != nil { + t.Fatalf("InjectPacket failed: %v", err) + } + + // Now when we send a response to this endpoint, it should go through netstack + endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} + responseData := []byte("response packet") + err = sharedBind.Send([][]byte{responseData}, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // The packet should be received by the client from the netstack connection + buf := make([]byte, 1024) + clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, fromAddr, err := clientConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive response: %v", err) + } + + if string(buf[:n]) != string(responseData) { + t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) + } + + // Verify the response came from the netstack connection, not the physical one + netstackAddr := netstackConn.LocalAddr().(*net.UDPAddr) + if fromAddr.Port != netstackAddr.Port { + t.Errorf("Expected response from netstack port %d, got %d", netstackAddr.Port, fromAddr.Port) + } +} + +// TestSocketRouting tests that packets from socket endpoints are routed through socket +func TestSocketRouting(t *testing.T) { + // Create the SharedBind with a physical UDP socket + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Create a mock "netstack" connection + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + // Set the netstack connection + sharedBind.SetNetstackConn(netstackConn) + + // Create a "client" that would receive packets (this simulates a hole-punched client) + clientConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create client UDP connection: %v", err) + } + defer clientConn.Close() + + clientAddr := clientConn.LocalAddr().(*net.UDPAddr) + clientAddrPort := clientAddr.AddrPort() + + // Don't inject from netstack - this endpoint is NOT tracked as netstack-sourced + // So Send should use the physical socket + + endpoint := &wgConn.StdNetEndpoint{AddrPort: clientAddrPort} + responseData := []byte("response packet via socket") + err = sharedBind.Send([][]byte{responseData}, endpoint) + if err != nil { + t.Fatalf("Send failed: %v", err) + } + + // The packet should be received by the client from the physical connection + buf := make([]byte, 1024) + clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + n, fromAddr, err := clientConn.ReadFromUDP(buf) + if err != nil { + t.Fatalf("Failed to receive response: %v", err) + } + + if string(buf[:n]) != string(responseData) { + t.Errorf("Expected to receive %q, got %q", responseData, buf[:n]) + } + + // Verify the response came from the physical connection, not the netstack one + physicalAddr := physicalConn.LocalAddr().(*net.UDPAddr) + if fromAddr.Port != physicalAddr.Port { + t.Errorf("Expected response from physical port %d, got %d", physicalAddr.Port, fromAddr.Port) + } +} + +// TestClearNetstackConn tests that clearing the netstack connection works correctly +func TestClearNetstackConn(t *testing.T) { + physicalConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create physical UDP connection: %v", err) + } + + sharedBind, err := New(physicalConn) + if err != nil { + t.Fatalf("Failed to create SharedBind: %v", err) + } + defer sharedBind.Close() + + // Set a netstack connection + netstackConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) + if err != nil { + t.Fatalf("Failed to create netstack UDP connection: %v", err) + } + defer netstackConn.Close() + + sharedBind.SetNetstackConn(netstackConn) + + // Inject a packet to track an endpoint + testAddrPort := netip.MustParseAddrPort("192.168.1.100:51820") + err = sharedBind.InjectPacket([]byte("test"), testAddrPort) + if err != nil { + t.Fatalf("InjectPacket failed: %v", err) + } + + // Verify the endpoint is tracked + _, tracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) + if !tracked { + t.Error("Expected endpoint to be tracked as netstack-sourced") + } + + // Clear the netstack connection + sharedBind.ClearNetstackConn() + + // Verify the netstack connection is cleared + if sharedBind.GetNetstackConn() != nil { + t.Error("Expected netstack connection to be nil after clear") + } + + // Verify the tracked endpoints are cleared + _, stillTracked := sharedBind.netstackEndpoints.Load(testAddrPort.String()) + if stillTracked { + t.Error("Expected endpoint tracking to be cleared") + } +} diff --git a/clients/clients.go b/clients/clients.go index 82420f0..68fb780 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -99,8 +99,10 @@ type WireGuardService struct { holePunchManager *holepunch.Manager useNativeInterface bool // Direct UDP relay from main tunnel to clients' WireGuard - directRelayStop chan struct{} - directRelayWg sync.WaitGroup + directRelayStop chan struct{} + directRelayWg sync.WaitGroup + netstackListener net.PacketConn + netstackListenerMu sync.Mutex } func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { @@ -300,6 +302,7 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) { // StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard. // This bypasses the proxy by listening on the main tunnel's netstack and forwarding packets // directly to the SharedBind that feeds the clients' WireGuard device. +// Responses are automatically routed back through the netstack by the SharedBind. // tunnelIP is the IP address to listen on within the main tunnel's netstack. func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error { if s.othertnet == nil { @@ -332,21 +335,33 @@ func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error { return fmt.Errorf("failed to listen on main tunnel netstack: %v", err) } - logger.Info("Started direct UDP relay on %s:%d (bypassing proxy)", tunnelIP, s.Port) + // Store the listener reference so we can close it later + s.netstackListenerMu.Lock() + s.netstackListener = listener + s.netstackListenerMu.Unlock() - // Start the relay goroutine + // Set the netstack connection on the SharedBind so responses go back through the tunnel + s.sharedBind.SetNetstackConn(listener) + + logger.Info("Started direct UDP relay on %s:%d (bidirectional via SharedBind)", tunnelIP, s.Port) + + // Start the relay goroutine to read from netstack and inject into SharedBind s.directRelayWg.Add(1) go s.runDirectUDPRelay(listener) return nil } -// runDirectUDPRelay handles the UDP relay between the main tunnel netstack and the SharedBind +// runDirectUDPRelay handles receiving UDP packets from the main tunnel netstack +// and injecting them into the SharedBind for processing by WireGuard. +// Responses are handled automatically by SharedBind.Send() which routes them +// back through the netstack connection. func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) { defer s.directRelayWg.Done() - defer listener.Close() + // Note: Don't close listener here - it's also used by SharedBind for sending responses + // It will be closed when the relay is stopped - logger.Info("Direct UDP relay started (injecting directly into SharedBind)") + logger.Info("Direct UDP relay started (bidirectional through SharedBind)") buf := make([]byte, 65535) // Max UDP packet size @@ -386,23 +401,36 @@ func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) { continue } - // Inject the packet directly into the SharedBind + // Inject the packet directly into the SharedBind (also tracks this endpoint as netstack-sourced) if err := s.sharedBind.InjectPacket(buf[:n], srcAddrPort); err != nil { logger.Debug("Failed to inject packet into SharedBind: %v", err) continue } - logger.Debug("Injected %d bytes from %s into SharedBind", n, srcAddrPort.String()) + logger.Debug("Relayed %d bytes from %s into WireGuard", n, srcAddrPort.String()) } } -// StopDirectUDPRelay stops the direct UDP relay +// StopDirectUDPRelay stops the direct UDP relay and closes the netstack listener func (s *WireGuardService) StopDirectUDPRelay() { if s.directRelayStop != nil { close(s.directRelayStop) s.directRelayWg.Wait() s.directRelayStop = nil } + + // Clear the netstack connection from SharedBind so responses don't try to use it + if s.sharedBind != nil { + s.sharedBind.ClearNetstackConn() + } + + // Close the netstack listener + s.netstackListenerMu.Lock() + if s.netstackListener != nil { + s.netstackListener.Close() + s.netstackListener = nil + } + s.netstackListenerMu.Unlock() } func (s *WireGuardService) LoadRemoteConfig() error {