mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Working but no wgtester? - revert if bad
This commit is contained in:
@@ -16,6 +16,25 @@ import (
|
||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
// PacketSource identifies where a packet came from
|
||||
type PacketSource uint8
|
||||
|
||||
const (
|
||||
SourceSocket PacketSource = iota // From physical UDP socket (hole-punched clients)
|
||||
SourceNetstack // From netstack (relay through main tunnel)
|
||||
)
|
||||
|
||||
// SourceAwareEndpoint wraps an endpoint with source information
|
||||
type SourceAwareEndpoint struct {
|
||||
wgConn.Endpoint
|
||||
source PacketSource
|
||||
}
|
||||
|
||||
// GetSource returns the source of this endpoint
|
||||
func (e *SourceAwareEndpoint) GetSource() PacketSource {
|
||||
return e.source
|
||||
}
|
||||
|
||||
// injectedPacket represents a packet injected into the SharedBind from an internal source
|
||||
type injectedPacket struct {
|
||||
data []byte
|
||||
@@ -59,10 +78,12 @@ func (e *Endpoint) SrcToString() string {
|
||||
// SharedBind is a thread-safe UDP bind that can be shared between WireGuard
|
||||
// and hole punch senders. It wraps a single UDP connection and implements
|
||||
// reference counting to prevent premature closure.
|
||||
// It also supports receiving packets from a netstack and routing responses
|
||||
// back through the appropriate source.
|
||||
type SharedBind struct {
|
||||
mu sync.RWMutex
|
||||
|
||||
// The underlying UDP connection
|
||||
// The underlying UDP connection (for hole-punched clients)
|
||||
udpConn *net.UDPConn
|
||||
|
||||
// IPv4 and IPv6 packet connections for advanced features
|
||||
@@ -79,8 +100,15 @@ type SharedBind struct {
|
||||
// Port binding information
|
||||
port uint16
|
||||
|
||||
// Channel for injected packets (from direct relay)
|
||||
injectedPackets chan injectedPacket
|
||||
// Channel for packets from netstack (from direct relay)
|
||||
netstackPackets chan injectedPacket
|
||||
|
||||
// Netstack connection for sending responses back through the tunnel
|
||||
netstackConn net.PacketConn
|
||||
netstackMu sync.RWMutex
|
||||
|
||||
// Track which endpoints came from netstack (key: AddrPort string, value: true)
|
||||
netstackEndpoints sync.Map
|
||||
}
|
||||
|
||||
// New creates a new SharedBind from an existing UDP connection.
|
||||
@@ -93,7 +121,7 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
|
||||
bind := &SharedBind{
|
||||
udpConn: udpConn,
|
||||
injectedPackets: make(chan injectedPacket, 256), // Buffer for injected packets
|
||||
netstackPackets: make(chan injectedPacket, 256), // Buffer for netstack packets
|
||||
}
|
||||
|
||||
// Initialize reference count to 1 (the creator holds the first reference)
|
||||
@@ -107,6 +135,21 @@ func New(udpConn *net.UDPConn) (*SharedBind, error) {
|
||||
return bind, nil
|
||||
}
|
||||
|
||||
// SetNetstackConn sets the netstack connection for receiving/sending packets through the tunnel.
|
||||
// This connection is used for relay traffic that should go back through the main tunnel.
|
||||
func (b *SharedBind) SetNetstackConn(conn net.PacketConn) {
|
||||
b.netstackMu.Lock()
|
||||
defer b.netstackMu.Unlock()
|
||||
b.netstackConn = conn
|
||||
}
|
||||
|
||||
// GetNetstackConn returns the netstack connection if set
|
||||
func (b *SharedBind) GetNetstackConn() net.PacketConn {
|
||||
b.netstackMu.RLock()
|
||||
defer b.netstackMu.RUnlock()
|
||||
return b.netstackConn
|
||||
}
|
||||
|
||||
// InjectPacket allows injecting a packet directly into the SharedBind's receive path.
|
||||
// This is used for direct relay from netstack without going through the host network.
|
||||
// The fromAddr should be the address the packet appears to come from.
|
||||
@@ -115,19 +158,22 @@ func (b *SharedBind) InjectPacket(data []byte, fromAddr netip.AddrPort) error {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Track this endpoint as coming from netstack so responses go back the same way
|
||||
b.netstackEndpoints.Store(fromAddr.String(), true)
|
||||
|
||||
// Make a copy of the data to avoid issues with buffer reuse
|
||||
dataCopy := make([]byte, len(data))
|
||||
copy(dataCopy, data)
|
||||
|
||||
select {
|
||||
case b.injectedPackets <- injectedPacket{
|
||||
case b.netstackPackets <- injectedPacket{
|
||||
data: dataCopy,
|
||||
endpoint: &wgConn.StdNetEndpoint{AddrPort: fromAddr},
|
||||
}:
|
||||
return nil
|
||||
default:
|
||||
// Channel full, drop the packet
|
||||
return fmt.Errorf("injected packet buffer full")
|
||||
return fmt.Errorf("netstack packet buffer full")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,9 +224,28 @@ func (b *SharedBind) closeConnection() error {
|
||||
b.ipv4PC = nil
|
||||
b.ipv6PC = nil
|
||||
|
||||
// Clear netstack connection (but don't close it - it's managed externally)
|
||||
b.netstackMu.Lock()
|
||||
b.netstackConn = nil
|
||||
b.netstackMu.Unlock()
|
||||
|
||||
// Clear tracked netstack endpoints
|
||||
b.netstackEndpoints = sync.Map{}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// ClearNetstackConn clears the netstack connection and tracked endpoints.
|
||||
// Call this when stopping the relay.
|
||||
func (b *SharedBind) ClearNetstackConn() {
|
||||
b.netstackMu.Lock()
|
||||
b.netstackConn = nil
|
||||
b.netstackMu.Unlock()
|
||||
|
||||
// Clear tracked netstack endpoints
|
||||
b.netstackEndpoints = sync.Map{}
|
||||
}
|
||||
|
||||
// GetUDPConn returns the underlying UDP connection.
|
||||
// The caller must not close this connection directly.
|
||||
func (b *SharedBind) GetUDPConn() *net.UDPConn {
|
||||
@@ -266,9 +331,9 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Check for injected packets first (non-blocking)
|
||||
// Check for netstack packets first (non-blocking)
|
||||
select {
|
||||
case pkt := <-b.injectedPackets:
|
||||
case pkt := <-b.netstackPackets:
|
||||
if len(pkt.data) <= len(bufs[0]) {
|
||||
copy(bufs[0], pkt.data)
|
||||
sizes[0] = len(pkt.data)
|
||||
@@ -276,7 +341,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
return 1, nil
|
||||
}
|
||||
default:
|
||||
// No injected packets, continue to check socket
|
||||
// No netstack packets, continue to check socket
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
@@ -288,7 +353,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
|
||||
// Set a short read deadline so we can poll for injected packets
|
||||
// Set a short read deadline so we can poll for netstack packets
|
||||
conn.SetReadDeadline(time.Now().Add(50 * time.Millisecond))
|
||||
|
||||
var n int
|
||||
@@ -302,7 +367,7 @@ func (b *SharedBind) makeReceiveIPv4() wgConn.ReceiveFunc {
|
||||
|
||||
if err != nil {
|
||||
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
|
||||
// Timeout - loop back to check for injected packets
|
||||
// Timeout - loop back to check for netstack packets
|
||||
continue
|
||||
}
|
||||
return n, err
|
||||
@@ -360,26 +425,19 @@ func (b *SharedBind) receiveIPv4Simple(conn *net.UDPConn, bufs [][]byte, sizes [
|
||||
}
|
||||
|
||||
// Send implements the WireGuard Bind interface.
|
||||
// It sends packets to the specified endpoint.
|
||||
// It sends packets to the specified endpoint, routing through the appropriate
|
||||
// source (netstack or physical socket) based on where the endpoint's packets came from.
|
||||
func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
if b.closed.Load() {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
// Extract the destination address from the endpoint
|
||||
var destAddr *net.UDPAddr
|
||||
var destAddrPort netip.AddrPort
|
||||
|
||||
// Try to cast to StdNetEndpoint first
|
||||
if stdEp, ok := ep.(*wgConn.StdNetEndpoint); ok {
|
||||
destAddr = net.UDPAddrFromAddrPort(stdEp.AddrPort)
|
||||
destAddrPort = stdEp.AddrPort
|
||||
} else {
|
||||
// Fallback: construct from DstIP and DstToBytes
|
||||
dstBytes := ep.DstToBytes()
|
||||
@@ -396,15 +454,46 @@ func (b *SharedBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||
}
|
||||
|
||||
if addr.IsValid() {
|
||||
destAddr = net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, port))
|
||||
destAddrPort = netip.AddrPortFrom(addr, port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if destAddr == nil {
|
||||
if !destAddrPort.IsValid() {
|
||||
return fmt.Errorf("could not extract destination address from endpoint")
|
||||
}
|
||||
|
||||
// Check if this endpoint came from netstack - if so, send through netstack
|
||||
if _, isNetstackEndpoint := b.netstackEndpoints.Load(destAddrPort.String()); isNetstackEndpoint {
|
||||
b.netstackMu.RLock()
|
||||
netstackConn := b.netstackConn
|
||||
b.netstackMu.RUnlock()
|
||||
|
||||
if netstackConn != nil {
|
||||
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
||||
// Send all buffers through netstack
|
||||
for _, buf := range bufs {
|
||||
_, err := netstackConn.WriteTo(buf, destAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Fall through to socket if netstack conn not available
|
||||
}
|
||||
|
||||
// Send through the physical UDP socket (for hole-punched clients)
|
||||
b.mu.RLock()
|
||||
conn := b.udpConn
|
||||
b.mu.RUnlock()
|
||||
|
||||
if conn == nil {
|
||||
return net.ErrClosed
|
||||
}
|
||||
|
||||
destAddr := net.UDPAddrFromAddrPort(destAddrPort)
|
||||
|
||||
// Send all buffers to the destination
|
||||
for _, buf := range bufs {
|
||||
_, err := conn.WriteToUDP(buf, destAddr)
|
||||
|
||||
Reference in New Issue
Block a user