diff --git a/bind/shared_bind.go b/bind/shared_bind.go index bff66bf..4a0e68d 100644 --- a/bind/shared_bind.go +++ b/bind/shared_bind.go @@ -9,12 +9,19 @@ import ( "runtime" "sync" "sync/atomic" + "time" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" wgConn "golang.zx2c4.com/wireguard/conn" ) +// injectedPacket represents a packet injected into the SharedBind from an internal source +type injectedPacket struct { + data []byte + endpoint wgConn.Endpoint +} + // Endpoint represents a network endpoint for the SharedBind type Endpoint struct { AddrPort netip.AddrPort @@ -71,6 +78,9 @@ type SharedBind struct { // Port binding information port uint16 + + // Channel for injected packets (from direct relay) + injectedPackets chan injectedPacket } // New creates a new SharedBind from an existing UDP connection. @@ -82,7 +92,8 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { } bind := &SharedBind{ - udpConn: udpConn, + udpConn: udpConn, + injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets } // Initialize reference count to 1 (the creator holds the first reference) @@ -96,6 +107,30 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { return bind, nil } +// InjectPacket allows injecting a packet directly into the SharedBind's receive path. +// This is used for direct relay from netstack without going through the host network. +// The fromAddr should be the address the packet appears to come from. +func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error { + if b.closed.Load() { + return net.ErrClosed + } + + // Make a copy of the data to avoid issues with buffer reuse + dataCopy := make([]byte, len(data)) + copy(dataCopy, data) + + select { + case b.injectedPackets <- injectedPacket{ + data: dataCopy, + endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr}, + }: + return nil + default: + // Channel full, drop the packet + return fmt.Errorf("injected packet buffer full") + } +} + // AddRef increments the reference count. Call this when sharing // the bind with another component. func (b *SharedBind) AddRef() { @@ -226,26 +261,54 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { // makeReceiveIPv4 creates a receive function for IPv4 packets func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - if b.closed.Load() { - return 0, net.ErrClosed + for { + if b.closed.Load() { + return 0, net.ErrClosed + } + + // Check for injected packets first (non-blocking) + select { + case pkt := <-b.injectedPackets: + if len(pkt.data) <= len(bufs[0]) { + copy(bufs[0], pkt.data) + sizes[0] = len(pkt.data) + eps[0] = pkt.endpoint + return 1, nil + } + default: + // No injected packets, continue to check socket + } + + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + // Set a short read deadline so we can poll for injected packets + conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) + + var n int + var err error + // Use batch reading on Linux for performance + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps) + } else { + n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps) + } + + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + // Timeout - loop back to check for injected packets + continue + } + return n, err + } + return n, nil } - - b.mu.RLock() - conn := b.udpConn - pc := b.ipv4PC - b.mu.RUnlock() - - if conn == nil { - return 0, net.ErrClosed - } - - // Use batch reading on Linux for performance - if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - return b.receiveIPv4Batch(pc, bufs, sizes, eps) - } - - // Fallback to simple read for other platforms - return b.receiveIPv4Simple(conn, bufs, sizes, eps) } } diff --git a/clients.go b/clients.go index dd5afba..42f9187 100644 --- a/clients.go +++ b/clients.go @@ -1,14 +1,12 @@ package main import ( - "fmt" "strings" "github.com/fosrl/newt/clients" wgnetstack "github.com/fosrl/newt/clients" "github.com/fosrl/newt/logger" "github.com/fosrl/newt/netstack2" - "github.com/fosrl/newt/proxy" "github.com/fosrl/newt/websocket" "golang.zx2c4.com/wireguard/tun/netstack" @@ -106,13 +104,15 @@ func clientsOnConnect() { } } -func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { +// clientsStartDirectRelay starts a direct UDP relay from the main tunnel netstack +// to the clients' WireGuard, bypassing the proxy for better performance. +func clientsStartDirectRelay(tunnelIP string) { if !ready { return } - // add a udp proxy for localost and the wgService port - // TODO: make sure this port is not used in a target if wgService != nil { - pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) + if err := wgService.StartDirectUDPRelay(tunnelIP); err != nil { + logger.Error("Failed to start direct UDP relay: %v", err) + } } } diff --git a/clients/clients.go b/clients/clients.go index 2f4289c..82420f0 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -98,6 +98,9 @@ type WireGuardService struct { sharedBind *bind.SharedBind holePunchManager *holepunch.Manager useNativeInterface bool + // Direct UDP relay from main tunnel to clients' WireGuard + directRelayStop chan struct{} + directRelayWg sync.WaitGroup } func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { @@ -211,6 +214,9 @@ func (s *WireGuardService) Close() { s.stopGetConfig = nil } + // Stop the direct UDP relay first + s.StopDirectUDPRelay() + // Stop hole punch manager if s.holePunchManager != nil { s.holePunchManager.Stop() @@ -291,6 +297,114 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) { } } +// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard. +// This bypasses the proxy by listening on the main tunnel's netstack and forwarding packets +// directly to the SharedBind that feeds the clients' WireGuard device. +// tunnelIP is the IP address to listen on within the main tunnel's netstack. +func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error { + if s.othertnet == nil { + return fmt.Errorf("main tunnel netstack (othertnet) not set") + } + if s.sharedBind == nil { + return fmt.Errorf("shared bind not initialized") + } + + // Stop any existing relay + s.StopDirectUDPRelay() + + s.directRelayStop = make(chan struct{}) + + // Parse the tunnel IP + ip := net.ParseIP(tunnelIP) + if ip == nil { + return fmt.Errorf("invalid tunnel IP: %s", tunnelIP) + } + + // Listen on the main tunnel netstack for UDP packets destined for the clients' WireGuard port + listenAddr := &net.UDPAddr{ + IP: ip, + Port: int(s.Port), + } + + // Use othertnet (main tunnel's netstack) to listen + listener, err := s.othertnet.ListenUDP(listenAddr) + if err != nil { + return fmt.Errorf("failed to listen on main tunnel netstack: %v", err) + } + + logger.Info("Started direct UDP relay on %s:%d (bypassing proxy)", tunnelIP, s.Port) + + // Start the relay goroutine + s.directRelayWg.Add(1) + go s.runDirectUDPRelay(listener) + + return nil +} + +// runDirectUDPRelay handles the UDP relay between the main tunnel netstack and the SharedBind +func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) { + defer s.directRelayWg.Done() + defer listener.Close() + + logger.Info("Direct UDP relay started (injecting directly into SharedBind)") + + buf := make([]byte, 65535) // Max UDP packet size + + for { + select { + case <-s.directRelayStop: + logger.Info("Stopping direct UDP relay") + return + default: + } + + // Set a read deadline so we can check for stop signal periodically + listener.SetReadDeadline(time.Now().Add(100 * time.Millisecond)) + + n, remoteAddr, err := listener.ReadFrom(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue // Just a timeout, check for stop and try again + } + if s.directRelayStop != nil { + select { + case <-s.directRelayStop: + return // Stopped + default: + } + } + logger.Debug("Direct UDP relay read error: %v", err) + continue + } + + // Get the source address + var srcAddrPort netip.AddrPort + if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok { + srcAddrPort = udpAddr.AddrPort() + } else { + logger.Debug("Unexpected address type in relay: %T", remoteAddr) + continue + } + + // Inject the packet directly into the SharedBind + if err := s.sharedBind.InjectPacket(buf[:n], srcAddrPort); err != nil { + logger.Debug("Failed to inject packet into SharedBind: %v", err) + continue + } + + logger.Debug("Injected %d bytes from %s into SharedBind", n, srcAddrPort.String()) + } +} + +// StopDirectUDPRelay stops the direct UDP relay +func (s *WireGuardService) StopDirectUDPRelay() { + if s.directRelayStop != nil { + close(s.directRelayStop) + s.directRelayWg.Wait() + s.directRelayStop = nil + } +} + func (s *WireGuardService) LoadRemoteConfig() error { if s.stopGetConfig != nil { s.stopGetConfig() diff --git a/main.go b/main.go index 2f7f9b3..a141141 100644 --- a/main.go +++ b/main.go @@ -742,7 +742,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey( // } } - clientsAddProxyTarget(pm, wgData.TunnelIP) + // Start direct UDP relay from main tunnel to clients' WireGuard (bypasses proxy) + clientsStartDirectRelay(wgData.TunnelIP) if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil { logger.Error("Failed to bulk add health check targets: %v", err) diff --git a/stub.go b/stub.go index 3bdbe19..e711da1 100644 --- a/stub.go +++ b/stub.go @@ -32,3 +32,8 @@ func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) { _ = tunnelIp // No-op for non-Linux systems } + +func clientsStartDirectRelayNative(tunnelIP string) { + _ = tunnelIP + // No-op for non-Linux systems +}