mirror of
https://github.com/fosrl/newt.git
synced 2026-03-08 03:36:40 +00:00
Speed much better!
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -100,15 +99,22 @@ type SharedBind struct {
|
|||||||
// Port binding information
|
// Port binding information
|
||||||
port uint16
|
port uint16
|
||||||
|
|
||||||
// Channel for packets from netstack (from direct relay)
|
// Channel for packets from netstack (from direct relay) - larger buffer for throughput
|
||||||
netstackPackets chan injectedPacket
|
netstackPackets chan injectedPacket
|
||||||
|
|
||||||
// Netstack connection for sending responses back through the tunnel
|
// Netstack connection for sending responses back through the tunnel
|
||||||
netstackConn net.PacketConn
|
// Using atomic.Pointer for lock-free access in hot path
|
||||||
netstackMu sync.RWMutex
|
netstackConn atomic.Pointer[net.PacketConn]
|
||||||
|
|
||||||
// Track which endpoints came from netstack (key: AddrPort string, value: true)
|
// Track which endpoints came from netstack (key: netip.AddrPort, value: struct{})
|
||||||
|
// Using netip.AddrPort directly as key is more efficient than string
|
||||||
netstackEndpoints sync.Map
|
netstackEndpoints sync.Map
|
||||||
|
|
||||||
|
// Pre-allocated message buffers for batch operations (Linux and Android only)
|
||||||
|
ipv4Msgs []ipv4.Message
|
||||||
|
|
||||||
|
// Shutdown signal for receive goroutines
|
||||||
|
closeChan chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new SharedBind from an existing UDP connection.
|
// New creates a new SharedBind from an existing UDP connection.
|
||||||
@@ -121,7 +127,8 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
|||||||
|
|
||||||
bind := &SharedBind{
|
bind := &SharedBind{
|
||||||
udpConn: udpConn,
|
udpConn: udpConn,
|
||||||
netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets
|
netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize reference count to 1 (the creator holds the first reference)
|
// Initialize reference count to 1 (the creator holds the first reference)
|
||||||
@@ -138,16 +145,16 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
|||||||
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
|
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
|
||||||
// This connection is used for relay traffic that should go back through the main tunnel.
|
// This connection is used for relay traffic that should go back through the main tunnel.
|
||||||
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
|
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
|
||||||
b.netstackMu.Lock()
|
b.netstackConn.Store(&conn)
|
||||||
defer b.netstackMu.Unlock()
|
|
||||||
b.netstackConn = conn
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetNetstackConn returns the netstack connection if set
|
// GetNetstackConn returns the netstack connection if set
|
||||||
func (b *SharedBind) GetNetstackConn() net.PacketConn {
|
func (b *SharedBind) GetNetstackConn() net.PacketConn {
|
||||||
b.netstackMu.RLock()
|
ptr := b.netstackConn.Load()
|
||||||
defer b.netstackMu.RUnlock()
|
if ptr == nil {
|
||||||
return b.netstackConn
|
return nil
|
||||||
|
}
|
||||||
|
return *ptr
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
|
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
|
||||||
@@ -159,7 +166,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Track this endpoint as coming from netstack so responses go back the same way
|
// Track this endpoint as coming from netstack so responses go back the same way
|
||||||
b.netstackEndpoints.Store(fromAddr.String(), true)
|
// Use AddrPort directly as key (more efficient than string)
|
||||||
|
b.netstackEndpoints.Store(fromAddr, struct{}{})
|
||||||
|
|
||||||
// Make a copy of the data to avoid issues with buffer reuse
|
// Make a copy of the data to avoid issues with buffer reuse
|
||||||
dataCopy := make([]byte, len(data))
|
dataCopy := make([]byte, len(data))
|
||||||
@@ -171,6 +179,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
|
|||||||
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
|
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
|
||||||
}:
|
}:
|
||||||
return nil
|
return nil
|
||||||
|
case <-b.closeChan:
|
||||||
|
return net.ErrClosed
|
||||||
default:
|
default:
|
||||||
// Channel full, drop the packet
|
// Channel full, drop the packet
|
||||||
return fmt.Errorf("netstack packet buffer full")
|
return fmt.Errorf("netstack packet buffer full")
|
||||||
@@ -212,6 +222,9 @@ func (b *SharedBind) closeConnection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Signal all goroutines to stop
|
||||||
|
close(b.closeChan)
|
||||||
|
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
|
|
||||||
@@ -225,9 +238,7 @@ func (b *SharedBind) closeConnection() error {
|
|||||||
b.ipv6PC = nil
|
b.ipv6PC = nil
|
||||||
|
|
||||||
// Clear netstack connection (but don't close it - it's managed externally)
|
// Clear netstack connection (but don't close it - it's managed externally)
|
||||||
b.netstackMu.Lock()
|
b.netstackConn.Store(nil)
|
||||||
b.netstackConn = nil
|
|
||||||
b.netstackMu.Unlock()
|
|
||||||
|
|
||||||
// Clear tracked netstack endpoints
|
// Clear tracked netstack endpoints
|
||||||
b.netstackEndpoints = sync.Map{}
|
b.netstackEndpoints = sync.Map{}
|
||||||
@@ -238,9 +249,7 @@ func (b *SharedBind) closeConnection() error {
|
|||||||
// ClearNetstackConn clears the netstack connection and tracked endpoints.
|
// ClearNetstackConn clears the netstack connection and tracked endpoints.
|
||||||
// Call this when stopping the relay.
|
// Call this when stopping the relay.
|
||||||
func (b *SharedBind) ClearNetstackConn() {
|
func (b *SharedBind) ClearNetstackConn() {
|
||||||
b.netstackMu.Lock()
|
b.netstackConn.Store(nil)
|
||||||
b.netstackConn = nil
|
|
||||||
b.netstackMu.Unlock()
|
|
||||||
|
|
||||||
// Clear tracked netstack endpoints
|
// Clear tracked netstack endpoints
|
||||||
b.netstackEndpoints = sync.Map{}
|
b.netstackEndpoints = sync.Map{}
|
||||||
@@ -306,99 +315,96 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
|
|||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
||||||
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
b.ipv4PC = ipv4.NewPacketConn(b.udpConn)
|
||||||
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
b.ipv6PC = ipv6.NewPacketConn(b.udpConn)
|
||||||
|
|
||||||
|
// Pre-allocate message buffers for batch operations
|
||||||
|
batchSize := wgConn.IdealBatchSize
|
||||||
|
b.ipv4Msgs = make([]ipv4.Message, batchSize)
|
||||||
|
for i := range b.ipv4Msgs {
|
||||||
|
b.ipv4Msgs[i].OOB = make([]byte, 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create receive functions
|
// Create receive functions - one for socket, one for netstack
|
||||||
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
recvFuncs := make([]wgConn.ReceiveFunc, 0, 2)
|
||||||
|
|
||||||
// Add IPv4 receive function
|
// Add socket receive function (reads from physical UDP socket)
|
||||||
if b.ipv4PC != nil || runtime.GOOS != "linux" {
|
recvFuncs = append(recvFuncs, b.makeReceiveSocket())
|
||||||
recvFuncs = append(recvFuncs, b.makeReceiveIPv4())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add IPv6 receive function if needed
|
// Add netstack receive function (reads from injected packets channel)
|
||||||
// For now, we focus on IPv4 for hole punching use case
|
recvFuncs = append(recvFuncs, b.makeReceiveNetstack())
|
||||||
|
|
||||||
b.recvFuncs = recvFuncs
|
b.recvFuncs = recvFuncs
|
||||||
return recvFuncs, b.port, nil
|
return recvFuncs, b.port, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeReceiveIPv4 creates a receive function for IPv4 packets
|
// makeReceiveSocket creates a receive function for physical UDP socket packets
|
||||||
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc {
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
for {
|
if b.closed.Load() {
|
||||||
if b.closed.Load() {
|
return 0, net.ErrClosed
|
||||||
return 0, net.ErrClosed
|
}
|
||||||
|
|
||||||
|
b.mu.RLock()
|
||||||
|
conn := b.udpConn
|
||||||
|
pc := b.ipv4PC
|
||||||
|
b.mu.RUnlock()
|
||||||
|
|
||||||
|
if conn == nil {
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use batch reading on Linux for performance
|
||||||
|
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
||||||
|
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeReceiveNetstack creates a receive function for netstack-injected packets
|
||||||
|
func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc {
|
||||||
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
|
select {
|
||||||
|
case <-b.closeChan:
|
||||||
|
return 0, net.ErrClosed
|
||||||
|
case pkt := <-b.netstackPackets:
|
||||||
|
if len(pkt.data) <= len(bufs[0]) {
|
||||||
|
copy(bufs[0], pkt.data)
|
||||||
|
sizes[0] = len(pkt.data)
|
||||||
|
eps[0] = pkt.endpoint
|
||||||
|
return 1, nil
|
||||||
}
|
}
|
||||||
|
// Packet too large for buffer, skip it
|
||||||
// Check for netstack packets first (non-blocking)
|
return 0, nil
|
||||||
select {
|
|
||||||
case pkt := <-b.netstackPackets:
|
|
||||||
if len(pkt.data) <= len(bufs[0]) {
|
|
||||||
copy(bufs[0], pkt.data)
|
|
||||||
sizes[0] = len(pkt.data)
|
|
||||||
eps[0] = pkt.endpoint
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
// No netstack packets, continue to check socket
|
|
||||||
}
|
|
||||||
|
|
||||||
b.mu.RLock()
|
|
||||||
conn := b.udpConn
|
|
||||||
pc := b.ipv4PC
|
|
||||||
b.mu.RUnlock()
|
|
||||||
|
|
||||||
if conn == nil {
|
|
||||||
return 0, net.ErrClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set a short read deadline so we can poll for netstack packets
|
|
||||||
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
|
||||||
|
|
||||||
var n int
|
|
||||||
var err error
|
|
||||||
// Use batch reading on Linux for performance
|
|
||||||
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
|
|
||||||
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
|
|
||||||
} else {
|
|
||||||
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
|
||||||
// Timeout - loop back to check for netstack packets
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// receiveIPv4Batch uses batch reading for better performance on Linux
|
// receiveIPv4Batch uses batch reading for better performance on Linux
|
||||||
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) {
|
||||||
// Create messages for batch reading
|
// Use pre-allocated messages, just update buffer pointers
|
||||||
msgs := make([]ipv4.Message, len(bufs))
|
numBufs := len(bufs)
|
||||||
for i := range bufs {
|
if numBufs > len(b.ipv4Msgs) {
|
||||||
msgs[i].Buffers = [][]byte{bufs[i]}
|
numBufs = len(b.ipv4Msgs)
|
||||||
msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use
|
|
||||||
}
|
}
|
||||||
|
|
||||||
numMsgs, err := pc.ReadBatch(msgs, 0)
|
for i := 0; i < numBufs; i++ {
|
||||||
|
b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]}
|
||||||
|
}
|
||||||
|
|
||||||
|
numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
for i := 0; i < numMsgs; i++ {
|
||||||
sizes[i] = msgs[i].N
|
sizes[i] = b.ipv4Msgs[i].N
|
||||||
if sizes[i] == 0 {
|
if sizes[i] == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if msgs[i].Addr != nil {
|
if b.ipv4Msgs[i].Addr != nil {
|
||||||
if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok {
|
if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok {
|
||||||
addrPort := udpAddr.AddrPort()
|
addrPort := udpAddr.AddrPort()
|
||||||
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort}
|
||||||
}
|
}
|
||||||
@@ -435,7 +441,7 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|||||||
// Extract the destination address from the endpoint
|
// Extract the destination address from the endpoint
|
||||||
var destAddrPort netip.AddrPort
|
var destAddrPort netip.AddrPort
|
||||||
|
|
||||||
// Try to cast to StdNetEndpoint first
|
// Try to cast to StdNetEndpoint first (most common case, avoid allocations)
|
||||||
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||||
destAddrPort = stdEp.AddrPort
|
destAddrPort = stdEp.AddrPort
|
||||||
} else {
|
} else {
|
||||||
@@ -464,12 +470,11 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if this endpoint came from netstack - if so, send through netstack
|
// Check if this endpoint came from netstack - if so, send through netstack
|
||||||
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint {
|
// Use AddrPort directly as key (more efficient than string conversion)
|
||||||
b.netstackMu.RLock()
|
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint {
|
||||||
netstackConn := b.netstackConn
|
connPtr := b.netstackConn.Load()
|
||||||
b.netstackMu.RUnlock()
|
if connPtr != nil && *connPtr != nil {
|
||||||
|
netstackConn := *connPtr
|
||||||
if netstackConn != nil {
|
|
||||||
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
||||||
// Send all buffers through netstack
|
// Send all buffers through netstack
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
|
|||||||
Reference in New Issue
Block a user