diff --git a/bind/shared_bind.go b/bind/shared_bind.go index d6d967c..52f9fcc 100644 --- a/bind/shared_bind.go +++ b/bind/shared_bind.go @@ -9,7 +9,6 @@ import ( "runtime" "sync" "sync/atomic" - "time" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -100,15 +99,22 @@ type SharedBind struct { // Port binding information port uint16 - // Channel for packets from netstack (from direct relay) + // Channel for packets from netstack (from direct relay) - larger buffer for throughput netstackPackets chan injectedPacket // Netstack connection for sending responses back through the tunnel - netstackConn net.PacketConn - netstackMu sync.RWMutex + // Using atomic.Pointer for lock-free access in hot path + netstackConn atomic.Pointer[net.PacketConn] - // Track which endpoints came from netstack (key: AddrPort string, value: true) + // Track which endpoints came from netstack (key: netip.AddrPort, value: struct{}) + // Using netip.AddrPort directly as key is more efficient than string netstackEndpoints sync.Map + + // Pre-allocated message buffers for batch operations (Linux only) + ipv4Msgs []ipv4.Message + + // Shutdown signal for receive goroutines + closeChan chan struct{} } // New creates a new SharedBind from an existing UDP connection. @@ -121,7 +127,8 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { bind := &SharedBind{ udpConn: udpConn, - netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets + netstackPackets: make(chan injectedPacket, 1024), // Larger buffer for better throughput + closeChan: make(chan struct{}), } // Initialize reference count to 1 (the creator holds the first reference) @@ -138,16 +145,16 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) { // SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel. // This connection is used for relay traffic that should go back through the main tunnel. func (b *SharedBind) SetNetstackConn(conn net.PacketConn) { - b.netstackMu.Lock() - defer b.netstackMu.Unlock() - b.netstackConn = conn + b.netstackConn.Store(&conn) } // GetNetstackConn returns the netstack connection if set func (b *SharedBind) GetNetstackConn() net.PacketConn { - b.netstackMu.RLock() - defer b.netstackMu.RUnlock() - return b.netstackConn + ptr := b.netstackConn.Load() + if ptr == nil { + return nil + } + return *ptr } // InjectPacket allows injecting a packet directly into the SharedBind's receive path. @@ -159,7 +166,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error { } // Track this endpoint as coming from netstack so responses go back the same way - b.netstackEndpoints.Store(fromAddr.String(), true) + // Use AddrPort directly as key (more efficient than string) + b.netstackEndpoints.Store(fromAddr, struct{}{}) // Make a copy of the data to avoid issues with buffer reuse dataCopy := make([]byte, len(data)) @@ -171,6 +179,8 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error { endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr}, }: return nil + case <-b.closeChan: + return net.ErrClosed default: // Channel full, drop the packet return fmt.Errorf("netstack packet buffer full") @@ -212,6 +222,9 @@ func (b *SharedBind) closeConnection() error { return nil } + // Signal all goroutines to stop + close(b.closeChan) + b.mu.Lock() defer b.mu.Unlock() @@ -225,9 +238,7 @@ func (b *SharedBind) closeConnection() error { b.ipv6PC = nil // Clear netstack connection (but don't close it - it's managed externally) - b.netstackMu.Lock() - b.netstackConn = nil - b.netstackMu.Unlock() + b.netstackConn.Store(nil) // Clear tracked netstack endpoints b.netstackEndpoints = sync.Map{} @@ -238,9 +249,7 @@ func (b *SharedBind) closeConnection() error { // ClearNetstackConn clears the netstack connection and tracked endpoints. // Call this when stopping the relay. func (b *SharedBind) ClearNetstackConn() { - b.netstackMu.Lock() - b.netstackConn = nil - b.netstackMu.Unlock() + b.netstackConn.Store(nil) // Clear tracked netstack endpoints b.netstackEndpoints = sync.Map{} @@ -306,99 +315,96 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) { if runtime.GOOS == "linux" || runtime.GOOS == "android" { b.ipv4PC = ipv4.NewPacketConn(b.udpConn) b.ipv6PC = ipv6.NewPacketConn(b.udpConn) + + // Pre-allocate message buffers for batch operations + batchSize := wgConn.IdealBatchSize + b.ipv4Msgs = make([]ipv4.Message, batchSize) + for i := range b.ipv4Msgs { + b.ipv4Msgs[i].OOB = make([]byte, 0) + } } - // Create receive functions + // Create receive functions - one for socket, one for netstack recvFuncs := make([]wgConn.ReceiveFunc, 0, 2) - // Add IPv4 receive function - if b.ipv4PC != nil || runtime.GOOS != "linux" { - recvFuncs = append(recvFuncs, b.makeReceiveIPv4()) - } + // Add socket receive function (reads from physical UDP socket) + recvFuncs = append(recvFuncs, b.makeReceiveSocket()) - // Add IPv6 receive function if needed - // For now, we focus on IPv4 for hole punching use case + // Add netstack receive function (reads from injected packets channel) + recvFuncs = append(recvFuncs, b.makeReceiveNetstack()) b.recvFuncs = recvFuncs return recvFuncs, b.port, nil } -// makeReceiveIPv4 creates a receive function for IPv4 packets -func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc { +// makeReceiveSocket creates a receive function for physical UDP socket packets +func (b *SharedBind) makeReceiveSocket() wgConn.ReceiveFunc { return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { - for { - if b.closed.Load() { - return 0, net.ErrClosed + if b.closed.Load() { + return 0, net.ErrClosed + } + + b.mu.RLock() + conn := b.udpConn + pc := b.ipv4PC + b.mu.RUnlock() + + if conn == nil { + return 0, net.ErrClosed + } + + // Use batch reading on Linux for performance + if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { + return b.receiveIPv4Batch(pc, bufs, sizes, eps) + } + return b.receiveIPv4Simple(conn, bufs, sizes, eps) + } +} + +// makeReceiveNetstack creates a receive function for netstack-injected packets +func (b *SharedBind) makeReceiveNetstack() wgConn.ReceiveFunc { + return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) { + select { + case <-b.closeChan: + return 0, net.ErrClosed + case pkt := <-b.netstackPackets: + if len(pkt.data) <= len(bufs[0]) { + copy(bufs[0], pkt.data) + sizes[0] = len(pkt.data) + eps[0] = pkt.endpoint + return 1, nil } - - // Check for netstack packets first (non-blocking) - select { - case pkt := <-b.netstackPackets: - if len(pkt.data) <= len(bufs[0]) { - copy(bufs[0], pkt.data) - sizes[0] = len(pkt.data) - eps[0] = pkt.endpoint - return 1, nil - } - default: - // No netstack packets, continue to check socket - } - - b.mu.RLock() - conn := b.udpConn - pc := b.ipv4PC - b.mu.RUnlock() - - if conn == nil { - return 0, net.ErrClosed - } - - // Set a short read deadline so we can poll for netstack packets - conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond)) - - var n int - var err error - // Use batch reading on Linux for performance - if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") { - n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps) - } else { - n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps) - } - - if err != nil { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - // Timeout - loop back to check for netstack packets - continue - } - return n, err - } - return n, nil + // Packet too large for buffer, skip it + return 0, nil } } } // receiveIPv4Batch uses batch reading for better performance on Linux func (b *SharedBind) receiveIPv4Batch(pc *ipv4.PacketConn, bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (int, error) { - // Create messages for batch reading - msgs := make([]ipv4.Message, len(bufs)) - for i := range bufs { - msgs[i].Buffers = [][]byte{bufs[i]} - msgs[i].OOB = make([]byte, 0) // No OOB data needed for basic use + // Use pre-allocated messages, just update buffer pointers + numBufs := len(bufs) + if numBufs > len(b.ipv4Msgs) { + numBufs = len(b.ipv4Msgs) } - numMsgs, err := pc.ReadBatch(msgs, 0) + for i := 0; i < numBufs; i++ { + b.ipv4Msgs[i].Buffers = [][]byte{bufs[i]} + } + + numMsgs, err := pc.ReadBatch(b.ipv4Msgs[:numBufs], 0) if err != nil { return 0, err } for i := 0; i < numMsgs; i++ { - sizes[i] = msgs[i].N + sizes[i] = b.ipv4Msgs[i].N if sizes[i] == 0 { continue } - if msgs[i].Addr != nil { - if udpAddr, ok := msgs[i].Addr.(*net.UDPAddr); ok { + if b.ipv4Msgs[i].Addr != nil { + if udpAddr, ok := b.ipv4Msgs[i].Addr.(*net.UDPAddr); ok { addrPort := udpAddr.AddrPort() eps[i] = &wgConn.StdNetEndpoint{AddrPort: addrPort} } @@ -435,7 +441,7 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { // Extract the destination address from the endpoint var destAddrPort netip.AddrPort - // Try to cast to StdNetEndpoint first + // Try to cast to StdNetEndpoint first (most common case, avoid allocations) if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok { destAddrPort = stdEp.AddrPort } else { @@ -464,12 +470,11 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error { } // Check if this endpoint came from netstack - if so, send through netstack - if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint { - b.netstackMu.RLock() - netstackConn := b.netstackConn - b.netstackMu.RUnlock() - - if netstackConn != nil { + // Use AddrPort directly as key (more efficient than string conversion) + if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort); isNetstackEndpoint { + connPtr := b.netstackConn.Load() + if connPtr != nil && *connPtr != nil { + netstackConn := *connPtr destAddr := net.UDPAddrFromAddrPort(destAddrPort) // Send all buffers through netstack for _, buf := range bufs {