Speed much better!

This commit is contained in:
Owen
2025-11-30 11:24:50 -05:00
parent de96be810b
commit cdaff27964

View File

@@ -9,7 +9,6 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@@ -100,15 +99,22 @@ type SharedBind struct {
// Port binding information // Port binding information
port uint16 port uint16
// Channel for packets from netstack (from direct relay) // Channel for packets from netstack (from direct relay) - larger buffer for throughput
netstackPackets chan injectedPacket netstackPackets chan injectedPacket
// Netstack connection for sending responses back through the tunnel // Netstack connection for sending responses back through the tunnel
netstackConn net.PacketConn // Using atomic.Pointer for lock-free access in hot path
netstackMu sync.RWMutex netstackConn atomic.Pointer[net.PacketConn]
// Track which endpoints came from netstack (key: AddrPort string, value: true) // Track which endpoints came from netstack (key: netip.AddrPort, value: struct{})
// Using netip.AddrPort directly as key is more efficient than string
netstackEndpoints sync.Map netstackEndpoints sync.Map
// Pre-allocated message buffers for batch operations (Linux only)
ipv4Msgs []ipv4.Message
// Shutdown signal for receive goroutines
closeChan chan struct{}
} }
// New creates a new SharedBind from an existing UDP connection. // New creates a new SharedBind from an existing UDP connection.
@@ -121,7 +127,8 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
bind := &SharedBind{ bind := &SharedBind{
udpConn: udpConn, udpConn: udpConn,
netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput
closeChan: make(chan struct{}),
} }
// Initialize reference count to 1 (the creator holds the first reference) // Initialize reference count to 1 (the creator holds the first reference)
@@ -138,16 +145,16 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel. // SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
// This connection is used for relay traffic that should go back through the main tunnel. // This connection is used for relay traffic that should go back through the main tunnel.
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) { func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
b.netstackMu.Lock() b.netstackConn.Store(&conn)
defer b.netstackMu.Unlock()
b.netstackConn = conn
} }
// GetNetstackConn returns the netstack connection if set // GetNetstackConn returns the netstack connection if set
func (b *SharedBind) GetNetstackConn() net.PacketConn { func (b *SharedBind) GetNetstackConn() net.PacketConn {
b.netstackMu.RLock() ptr := b.netstackConn.Load()
defer b.netstackMu.RUnlock() if ptr == nil {
return b.netstackConn return nil
}
return *ptr
} }
// InjectPacket allows injecting a packet directly into the SharedBind's receive path. // InjectPacket allows injecting a packet directly into the SharedBind's receive path.
@@ -159,7 +166,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
} }
// Track this endpoint as coming from netstack so responses go back the same way // Track this endpoint as coming from netstack so responses go back the same way
b.netstackEndpoints.Store(fromAddr.String(), true) // Use AddrPort directly as key (more efficient than string)
b.netstackEndpoints.Store(fromAddr, struct{}{})
// Make a copy of the data to avoid issues with buffer reuse // Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data)) dataCopy := make([]byte, len(data))
@@ -171,6 +179,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr}, endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}: }:
return nil return nil
case <-b.closeChan:
return net.ErrClosed
default: default:
// Channel full, drop the packet // Channel full, drop the packet
return fmt.Errorf("netstack packet buffer full") return fmt.Errorf("netstack packet buffer full")
@@ -212,6 +222,9 @@ func (b *SharedBind) closeConnection() error {
return nil return nil
} }
// Signal all goroutines to stop
close(b.closeChan)
b.mu.Lock() b.mu.Lock()
defer b.mu.Unlock() defer b.mu.Unlock()
@@ -225,9 +238,7 @@ func (b *SharedBind) closeConnection() error {
b.ipv6PC = nil b.ipv6PC = nil
// Clear netstack connection (but don't close it - it's managed externally) // Clear netstack connection (but don't close it - it's managed externally)
b.netstackMu.Lock() b.netstackConn.Store(nil)
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints // Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{} b.netstackEndpoints = sync.Map{}
@@ -238,9 +249,7 @@ func (b *SharedBind) closeConnection() error {
// ClearNetstackConn clears the netstack connection and tracked endpoints. // ClearNetstackConn clears the netstack connection and tracked endpoints.
// Call this when stopping the relay. // Call this when stopping the relay.
func (b *SharedBind) ClearNetstackConn() { func (b *SharedBind) ClearNetstackConn() {
b.netstackMu.Lock() b.netstackConn.Store(nil)
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints // Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{} b.netstackEndpoints = sync.Map{}
@@ -306,99 +315,96 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
if runtime.GOOS == "linux" || runtime.GOOS == "android" { if runtime.GOOS == "linux" || runtime.GOOS == "android" {
b.ipv4PC = ipv4.NewPacketConn(b.udpConn) b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
b.ipv6PC = ipv6.NewPacketConn(b.udpConn) b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
// Pre-allocate message buffers for batch operations
batchSize := wgConn.IdealBatchSize
b.ipv4Msgs = make([]ipv4.Message, batchSize)
for i := range b.ipv4Msgs {
b.ipv4Msgs[i].OOB = make([]byte, 0)
}
} }
// Create receive functions // Create receive functions - one for socket, one for netstack
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
// Add IPv4 receive function // Add socket receive function (reads from physical UDP socket)
if b.ipv4PC != nil || runtime.GOOS != "linux" { recvFuncs = append(recvFuncs, b.makeReceiveSocket())
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
}
// Add IPv6 receive function if needed // Add netstack receive function (reads from injected packets channel)
// For now, we focus on IPv4 for hole punching use case recvFuncs = append(recvFuncs, b.makeReceiveNetstack())
b.recvFuncs = recvFuncs b.recvFuncs = recvFuncs
return recvFuncs, b.port, nil return recvFuncs, b.port, nil
} }
// makeReceiveIPv4 creates a receive function for IPv4 packets // makeReceiveSocket creates a receive function for physical UDP socket packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
for { if b.closed.Load() {
if b.closed.Load() { return 0, net.ErrClosed
return 0, net.ErrClosed }
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}
// makeReceiveNetstack creates a receive function for netstack-injected packets
func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
select {
case <-b.closeChan:
return 0, net.ErrClosed
case pkt := <-b.netstackPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
eps[0] = pkt.endpoint
return 1, nil
} }
// Packet too large for buffer, skip it
// Check for netstack packets first (non-blocking) return 0, nil
select {
case pkt := <-b.netstackPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
eps[0] = pkt.endpoint
return 1, nil
}
default:
// No netstack packets, continue to check socket
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Set a short read deadline so we can poll for netstack packets
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
var n int
var err error
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
} else {
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - loop back to check for netstack packets
continue
}
return n, err
}
return n, nil
} }
} }
} }
// receiveIPv4Batch uses batch reading for better performance on Linux // receiveIPv4Batch uses batch reading for better performance on Linux
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
// Create messages for batch reading // Use pre-allocated messages, just update buffer pointers
msgs := make([]ipv4.Message, len(bufs)) numBufs := len(bufs)
for i := range bufs { if numBufs > len(b.ipv4Msgs) {
msgs[i].Buffers = [][]byte{bufs[i]} numBufs = len(b.ipv4Msgs)
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
} }
numMsgs, err := pc.ReadBatch(msgs, 0) for i := 0; i < numBufs; i++ {
b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]}
}
numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0)
if err != nil { if err != nil {
return 0, err return 0, err
} }
for i := 0; i < numMsgs; i++ { for i := 0; i < numMsgs; i++ {
sizes[i] = msgs[i].N sizes[i] = b.ipv4Msgs[i].N
if sizes[i] == 0 { if sizes[i] == 0 {
continue continue
} }
if msgs[i].Addr != nil { if b.ipv4Msgs[i].Addr != nil {
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok { if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
addrPort := udpAddr.AddrPort() addrPort := udpAddr.AddrPort()
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
} }
@@ -435,7 +441,7 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
// Extract the destination address from the endpoint // Extract the destination address from the endpoint
var destAddrPort netip.AddrPort var destAddrPort netip.AddrPort
// Try to cast to StdNetEndpoint first // Try to cast to StdNetEndpoint first (most common case, avoid allocations)
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddrPort = stdEp.AddrPort destAddrPort = stdEp.AddrPort
} else { } else {
@@ -464,12 +470,11 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
} }
// Check if this endpoint came from netstack - if so, send through netstack // Check if this endpoint came from netstack - if so, send through netstack
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint { // Use AddrPort directly as key (more efficient than string conversion)
b.netstackMu.RLock() if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint {
netstackConn := b.netstackConn connPtr := b.netstackConn.Load()
b.netstackMu.RUnlock() if connPtr != nil && *connPtr != nil {
netstackConn := *connPtr
if netstackConn != nil {
destAddr := net.UDPAddrFromAddrPort(destAddrPort) destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers through netstack // Send all buffers through netstack
for _, buf := range bufs { for _, buf := range bufs {