Working but no wgtester? - revert if bad

This commit is contained in:
Owen
2025-11-29 17:38:34 -05:00
parent 5196effdb8
commit de96be810b
3 changed files with 332 additions and 34 deletions

View File

@@ -16,6 +16,25 @@ import (
wgConn "golang.zx2c4.com/wireguard/conn"
)
// PacketSource identifies where a packet came from
type PacketSource uint8
const (
SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients)
SourceNetstack // From netstack (relay through main tunnel)
)
// SourceAwareEndpoint wraps an endpoint with source information
type SourceAwareEndpoint struct {
wgConn.Endpoint
source PacketSource
}
// GetSource returns the source of this endpoint
func (e *SourceAwareEndpoint) GetSource() PacketSource {
return e.source
}
// injectedPacket represents a packet injected into the SharedBind from an internal source
type injectedPacket struct {
data []byte
@@ -59,10 +78,12 @@ func (e *Endpoint) SrcToString() string {
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
// and hole punch senders. It wraps a single UDP connection and implements
// reference counting to prevent premature closure.
// It also supports receiving packets from a netstack and routing responses
// back through the appropriate source.
type SharedBind struct {
mu sync.RWMutex
// The underlying UDP connection
// The underlying UDP connection (for hole-punched clients)
udpConn *net.UDPConn
// IPv4 and IPv6 packet connections for advanced features
@@ -79,8 +100,15 @@ type SharedBind struct {
// Port binding information
port uint16
// Channel for injected packets (from direct relay)
injectedPackets chan injectedPacket
// Channel for packets from netstack (from direct relay)
netstackPackets chan injectedPacket
// Netstack connection for sending responses back through the tunnel
netstackConn net.PacketConn
netstackMu sync.RWMutex
// Track which endpoints came from netstack (key: AddrPort string, value: true)
netstackEndpoints sync.Map
}
// New creates a new SharedBind from an existing UDP connection.
@@ -93,7 +121,7 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
bind := &SharedBind{
udpConn: udpConn,
injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets
netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets
}
// Initialize reference count to 1 (the creator holds the first reference)
@@ -107,6 +135,21 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
return bind, nil
}
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
// This connection is used for relay traffic that should go back through the main tunnel.
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
b.netstackMu.Lock()
defer b.netstackMu.Unlock()
b.netstackConn = conn
}
// GetNetstackConn returns the netstack connection if set
func (b *SharedBind) GetNetstackConn() net.PacketConn {
b.netstackMu.RLock()
defer b.netstackMu.RUnlock()
return b.netstackConn
}
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
// This is used for direct relay from netstack without going through the host network.
// The fromAddr should be the address the packet appears to come from.
@@ -115,19 +158,22 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
return net.ErrClosed
}
// Track this endpoint as coming from netstack so responses go back the same way
b.netstackEndpoints.Store(fromAddr.String(), true)
// Make a copy of the data to avoid issues with buffer reuse
dataCopy := make([]byte, len(data))
copy(dataCopy, data)
select {
case b.injectedPackets <- injectedPacket{
case b.netstackPackets <- injectedPacket{
data: dataCopy,
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
}:
return nil
default:
// Channel full, drop the packet
return fmt.Errorf("injected packet buffer full")
return fmt.Errorf("netstack packet buffer full")
}
}
@@ -178,9 +224,28 @@ func (b *SharedBind) closeConnection() error {
b.ipv4PC = nil
b.ipv6PC = nil
// Clear netstack connection (but don't close it - it's managed externally)
b.netstackMu.Lock()
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
return err
}
// ClearNetstackConn clears the netstack connection and tracked endpoints.
// Call this when stopping the relay.
func (b *SharedBind) ClearNetstackConn() {
b.netstackMu.Lock()
b.netstackConn = nil
b.netstackMu.Unlock()
// Clear tracked netstack endpoints
b.netstackEndpoints = sync.Map{}
}
// GetUDPConn returns the underlying UDP connection.
// The caller must not close this connection directly.
func (b *SharedBind) GetUDPConn() *net.UDPConn {
@@ -266,9 +331,9 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 0, net.ErrClosed
}
// Check for injected packets first (non-blocking)
// Check for netstack packets first (non-blocking)
select {
case pkt := <-b.injectedPackets:
case pkt := <-b.netstackPackets:
if len(pkt.data) <= len(bufs[0]) {
copy(bufs[0], pkt.data)
sizes[0] = len(pkt.data)
@@ -276,7 +341,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 1, nil
}
default:
// No injected packets, continue to check socket
// No netstack packets, continue to check socket
}
b.mu.RLock()
@@ -288,7 +353,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
return 0, net.ErrClosed
}
// Set a short read deadline so we can poll for injected packets
// Set a short read deadline so we can poll for netstack packets
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
var n int
@@ -302,7 +367,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
// Timeout - loop back to check for injected packets
// Timeout - loop back to check for netstack packets
continue
}
return n, err
@@ -360,26 +425,19 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [
}
// Send implements the WireGuard Bind interface.
// It sends packets to the specified endpoint.
// It sends packets to the specified endpoint, routing through the appropriate
// source (netstack or physical socket) based on where the endpoint's packets came from.
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
if b.closed.Load() {
return net.ErrClosed
}
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
// Extract the destination address from the endpoint
var destAddr *net.UDPAddr
var destAddrPort netip.AddrPort
// Try to cast to StdNetEndpoint first
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
destAddrPort = stdEp.AddrPort
} else {
// Fallback: construct from DstIP and DstToBytes
dstBytes := ep.DstToBytes()
@@ -396,15 +454,46 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
}
if addr.IsValid() {
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
destAddrPort = netip.AddrPortFrom(addr, port)
}
}
}
if destAddr == nil {
if !destAddrPort.IsValid() {
return fmt.Errorf("could not extract destination address from endpoint")
}
// Check if this endpoint came from netstack - if so, send through netstack
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint {
b.netstackMu.RLock()
netstackConn := b.netstackConn
b.netstackMu.RUnlock()
if netstackConn != nil {
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers through netstack
for _, buf := range bufs {
_, err := netstackConn.WriteTo(buf, destAddr)
if err != nil {
return err
}
}
return nil
}
// Fall through to socket if netstack conn not available
}
// Send through the physical UDP socket (for hole-punched clients)
b.mu.RLock()
conn := b.udpConn
b.mu.RUnlock()
if conn == nil {
return net.ErrClosed
}
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
// Send all buffers to the destination
for _, buf := range bufs {
_, err := conn.WriteToUDP(buf, destAddr)