Kind of working - revert if not

This commit is contained in:
Owen
2025-11-26 17:57:27 -05:00
parent d6edd6ca01
commit 5196effdb8
5 changed files with 210 additions and 27 deletions

View File

@@ -9,12 +9,19 @@ import (
"runtime"
"sync"
"sync/atomic"
"time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
wgConn "golang.zx2c4.com/wireguard/conn"
)
// injectedPacket represents a packet injected into the SharedBind from an internal source
type injectedPacket struct {
data []byte
endpoint wgConn.Endpoint
}
// Endpoint represents a network endpoint for the SharedBind
type Endpoint struct {
AddrPort netip.AddrPort
@@ -71,6 +78,9 @@ type SharedBind struct {
// Port binding information
port uint16
// Channel for injected packets (from direct relay)
injectedPackets chan injectedPacket
}
// New creates a new SharedBind from an existing UDP connection.
@@ -82,7 +92,8 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
}
bind := &SharedBind{
udpConn: udpConn,
udpConn: udpConn,
injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets
}
// Initialize reference count to 1 (the creator holds the first reference)
@@ -96,6 +107,30 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
return bind, nil
}
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
// This is used for direct relay from netstack without going through the host network.
// The fromAddr should be the address the packet appears to come from.
func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
if b.closed.Load() {
return net.ErrClosed
}
// Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case b.injectedPackets <- injectedPacket{
data: dataCopy,
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}:
return nil
default:
// Channel full, drop the packet
return fmt.Errorf("injected packet buffer full")
}
}
// AddRef increments the reference count. Call this when sharing
// the bind with another component.
func (b *SharedBind) AddRef() {
@@ -226,26 +261,54 @@ func (b *SharedBind) Open(uport uint16) ([]wgConn.ReceiveFunc, uint16, error) {
// makeReceiveIPv4 creates a receive function for IPv4 packets
func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
if b.closed.Load() {
return 0, net.ErrClosed
for {
if b.closed.Load() {
return 0, net.ErrClosed
}
// Check for injected packets first (non-blocking)
select {
case pkt := <-b.injectedPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
eps[0] = pkt.endpoint
return 1, nil
}
default:
// No injected packets, continue to check socket
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Set a short read deadline so we can poll for injected packets
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
var n int
var err error
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
n, err = b.receiveIPv4Batch(pc, bufs, sizes, eps)
} else {
n, err = b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - loop back to check for injected packets
continue
}
return n, err
}
return n, nil
}
b.mu.RLock()
conn := b.udpConn
pc := b.ipv4PC
b.mu.RUnlock()
if conn == nil {
return 0, net.ErrClosed
}
// Use batch reading on Linux for performance
if pc != nil && (runtime.GOOS == "linux" || runtime.GOOS == "android") {
return b.receiveIPv4Batch(pc, bufs, sizes, eps)
}
// Fallback to simple read for other platforms
return b.receiveIPv4Simple(conn, bufs, sizes, eps)
}
}