Kind of working - revert if not

This commit is contained in:
Owen
2025-11-26 17:57:27 -05:00
parent d6edd6ca01
commit 5196effdb8
5 changed files with 210 additions and 27 deletions

View File

@@ -9,12 +9,19 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn" wgConn "golang.zx2c4.com/wireguard/conn"
) )
// injectedPacket represents a packet injected into the SharedBind from an internal source
type injectedPacket struct {
data []byte
endpoint wgConn.Endpoint
}
// Endpoint represents a network endpoint for the SharedBind // Endpoint represents a network endpoint for the SharedBind
type Endpoint struct { type Endpoint struct {
AddrPort netip.AddrPort AddrPort netip.AddrPort
@@ -71,6 +78,9 @@ type SharedBind struct {
// Port binding information // Port binding information
port uint16 port uint16
// Channel for injected packets (from direct relay)
injectedPackets chan injectedPacket
} }
// New creates a new SharedBind from an existing UDP connection. // New creates a new SharedBind from an existing UDP connection.
@@ -83,6 +93,7 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
bind := &SharedBind{ bind := &SharedBind{
udpConn: udpConn, udpConn: udpConn,
injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets
} }
// Initialize reference count to 1 (the creator holds the first reference) // Initialize reference count to 1 (the creator holds the first reference)
@@ -96,6 +107,30 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
return bind, nil return bind, nil
} }
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
// This is used for direct relay from netstack without going through the host network.
// The fromAddr should be the address the packet appears to come from.
func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
if b.closed.Load() {
return net.ErrClosed
}
// Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case b.injectedPackets <- injectedPacket{
data: dataCopy,
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}:
return nil
default:
// Channel full, drop the packet
return fmt.Errorf("injected packet buffer full")
}
}
// AddRef increments the reference count. Call this when sharing // AddRef increments the reference count. Call this when sharing
// the bind with another component. // the bind with another component.
func (b *SharedBind) AddRef() { func (b *SharedBind) AddRef() {
@@ -226,10 +261,24 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
// makeReceiveIPv4 creates a receive function for IPv4 packets // makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
for {
if b.closed.Load() { if b.closed.Load() {
return 0, net.ErrClosed return 0, net.ErrClosed
} }
// Check for injected packets first (non-blocking)
select {
case pkt := <-b.injectedPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
eps[0] = pkt.endpoint
return 1, nil
}
default:
// No injected packets, continue to check socket
}
b.mu.RLock() b.mu.RLock()
conn := b.udpConn conn := b.udpConn
pc := b.ipv4PC pc := b.ipv4PC
@@ -239,13 +288,27 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 0, net.ErrClosed return 0, net.ErrClosed
} }
// Set a short read deadline so we can poll for injected packets
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
var n int
var err error
// Use batch reading on Linux for performance // Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps) n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
} else {
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
} }
// Fallback to simple read for other platforms if err != nil {
return b.receiveIPv4Simple(conn, bufs, sizes, eps) if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - loop back to check for injected packets
continue
}
return n, err
}
return n, nil
}
} }
} }

View File

@@ -1,14 +1,12 @@
package main package main
import ( import (
"fmt"
"strings" "strings"
"github.com/fosrl/newt/clients" "github.com/fosrl/newt/clients"
wgnetstack "github.com/fosrl/newt/clients" wgnetstack "github.com/fosrl/newt/clients"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"github.com/fosrl/newt/netstack2" "github.com/fosrl/newt/netstack2"
"github.com/fosrl/newt/proxy"
"github.com/fosrl/newt/websocket" "github.com/fosrl/newt/websocket"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
@@ -106,13 +104,15 @@ func clientsOnConnect() {
} }
} }
func clientsAddProxyTarget(pm *proxy.ProxyManager, tunnelIp string) { // clientsStartDirectRelay starts a direct UDP relay from the main tunnel netstack
// to the clients' WireGuard, bypassing the proxy for better performance.
func clientsStartDirectRelay(tunnelIP string) {
if !ready { if !ready {
return return
} }
// add a udp proxy for localost and the wgService port
// TODO: make sure this port is not used in a target
if wgService != nil { if wgService != nil {
pm.AddTarget("udp", tunnelIp, int(wgService.Port), fmt.Sprintf("127.0.0.1:%d", wgService.Port)) if err := wgService.StartDirectUDPRelay(tunnelIP); err != nil {
logger.Error("Failed to start direct UDP relay: %v", err)
}
} }
} }

View File

@@ -98,6 +98,9 @@ type WireGuardService struct {
sharedBind *bind.SharedBind sharedBind *bind.SharedBind
holePunchManager *holepunch.Manager holePunchManager *holepunch.Manager
useNativeInterface bool useNativeInterface bool
// Direct UDP relay from main tunnel to clients' WireGuard
directRelayStop chan struct{}
directRelayWg sync.WaitGroup
} }
func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) { func NewWireGuardService(interfaceName string, mtu int, generateAndSaveKeyTo string, host string, newtId string, wsClient *websocket.Client, dns string, useNativeInterface bool) (*WireGuardService, error) {
@@ -211,6 +214,9 @@ func (s *WireGuardService) Close() {
s.stopGetConfig = nil s.stopGetConfig = nil
} }
// Stop the direct UDP relay first
s.StopDirectUDPRelay()
// Stop hole punch manager // Stop hole punch manager
if s.holePunchManager != nil { if s.holePunchManager != nil {
s.holePunchManager.Stop() s.holePunchManager.Stop()
@@ -291,6 +297,114 @@ func (s *WireGuardService) StartHolepunch(publicKey string, endpoint string) {
} }
} }
// StartDirectUDPRelay starts a direct UDP relay from the main tunnel netstack to the clients' WireGuard.
// This bypasses the proxy by listening on the main tunnel's netstack and forwarding packets
// directly to the SharedBind that feeds the clients' WireGuard device.
// tunnelIP is the IP address to listen on within the main tunnel's netstack.
func (s *WireGuardService) StartDirectUDPRelay(tunnelIP string) error {
if s.othertnet == nil {
return fmt.Errorf("main tunnel netstack (othertnet) not set")
}
if s.sharedBind == nil {
return fmt.Errorf("shared bind not initialized")
}
// Stop any existing relay
s.StopDirectUDPRelay()
s.directRelayStop = make(chan struct{})
// Parse the tunnel IP
ip := net.ParseIP(tunnelIP)
if ip == nil {
return fmt.Errorf("invalid tunnel IP: %s", tunnelIP)
}
// Listen on the main tunnel netstack for UDP packets destined for the clients' WireGuard port
listenAddr := &net.UDPAddr{
IP: ip,
Port: int(s.Port),
}
// Use othertnet (main tunnel's netstack) to listen
listener, err := s.othertnet.ListenUDP(listenAddr)
if err != nil {
return fmt.Errorf("failed to listen on main tunnel netstack: %v", err)
}
logger.Info("Started direct UDP relay on %s:%d (bypassing proxy)", tunnelIP, s.Port)
// Start the relay goroutine
s.directRelayWg.Add(1)
go s.runDirectUDPRelay(listener)
return nil
}
// runDirectUDPRelay handles the UDP relay between the main tunnel netstack and the SharedBind
func (s *WireGuardService) runDirectUDPRelay(listener net.PacketConn) {
defer s.directRelayWg.Done()
defer listener.Close()
logger.Info("Direct UDP relay started (injecting directly into SharedBind)")
buf := make([]byte, 65535) // Max UDP packet size
for {
select {
case <-s.directRelayStop:
logger.Info("Stopping direct UDP relay")
return
default:
}
// Set a read deadline so we can check for stop signal periodically
listener.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
n, remoteAddr, err := listener.ReadFrom(buf)
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue // Just a timeout, check for stop and try again
}
if s.directRelayStop != nil {
select {
case <-s.directRelayStop:
return // Stopped
default:
}
}
logger.Debug("Direct UDP relay read error: %v", err)
continue
}
// Get the source address
var srcAddrPort netip.AddrPort
if udpAddr, ok := remoteAddr.(*net.UDPAddr); ok {
srcAddrPort = udpAddr.AddrPort()
} else {
logger.Debug("Unexpected address type in relay: %T", remoteAddr)
continue
}
// Inject the packet directly into the SharedBind
if err := s.sharedBind.InjectPacket(buf[:n], srcAddrPort); err != nil {
logger.Debug("Failed to inject packet into SharedBind: %v", err)
continue
}
logger.Debug("Injected %d bytes from %s into SharedBind", n, srcAddrPort.String())
}
}
// StopDirectUDPRelay stops the direct UDP relay
func (s *WireGuardService) StopDirectUDPRelay() {
if s.directRelayStop != nil {
close(s.directRelayStop)
s.directRelayWg.Wait()
s.directRelayStop = nil
}
}
func (s *WireGuardService) LoadRemoteConfig() error { func (s *WireGuardService) LoadRemoteConfig() error {
if s.stopGetConfig != nil { if s.stopGetConfig != nil {
s.stopGetConfig() s.stopGetConfig()

View File

@@ -742,7 +742,8 @@ persistent_keepalive_interval=5`, util.FixKey(privateKey.String()), util.FixKey(
// } // }
} }
clientsAddProxyTarget(pm, wgData.TunnelIP) // Start direct UDP relay from main tunnel to clients' WireGuard (bypasses proxy)
clientsStartDirectRelay(wgData.TunnelIP)
if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil { if err := healthMonitor.AddTargets(wgData.HealthCheckTargets); err != nil {
logger.Error("Failed to bulk add health check targets: %v", err) logger.Error("Failed to bulk add health check targets: %v", err)

View File

@@ -32,3 +32,8 @@ func clientsAddProxyTargetNative(pm *proxy.ProxyManager, tunnelIp string) {
_ = tunnelIp _ = tunnelIp
// No-op for non-Linux systems // No-op for non-Linux systems
} }
func clientsStartDirectRelayNative(tunnelIP string) {
_ = tunnelIP
// No-op for non-Linux systems
}