mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Add non-root ICMP support to userspace firewall forwarder (#4792)
This commit is contained in:
@@ -2,6 +2,7 @@ package forwarder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"gvisor.dev/gvisor/pkg/tcpip"
|
||||
@@ -16,7 +17,7 @@ type endpoint struct {
|
||||
logger *nblog.Logger
|
||||
dispatcher stack.NetworkDispatcher
|
||||
device *wgdevice.Device
|
||||
mtu uint32
|
||||
mtu atomic.Uint32
|
||||
}
|
||||
|
||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||
@@ -28,7 +29,7 @@ func (e *endpoint) IsAttached() bool {
|
||||
}
|
||||
|
||||
func (e *endpoint) MTU() uint32 {
|
||||
return e.mtu
|
||||
return e.mtu.Load()
|
||||
}
|
||||
|
||||
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||
@@ -82,6 +83,22 @@ func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (e *endpoint) Close() {
|
||||
// Endpoint cleanup - nothing to do as device is managed externally
|
||||
}
|
||||
|
||||
func (e *endpoint) SetLinkAddress(tcpip.LinkAddress) {
|
||||
// Link address is not used for this endpoint type
|
||||
}
|
||||
|
||||
func (e *endpoint) SetMTU(mtu uint32) {
|
||||
e.mtu.Store(mtu)
|
||||
}
|
||||
|
||||
func (e *endpoint) SetOnCloseAction(func()) {
|
||||
// No action needed on close
|
||||
}
|
||||
|
||||
type epID stack.TransportEndpointID
|
||||
|
||||
func (i epID) String() string {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gvisor.dev/gvisor/pkg/buffer"
|
||||
@@ -35,14 +36,16 @@ type Forwarder struct {
|
||||
logger *nblog.Logger
|
||||
flowLogger nftypes.FlowLogger
|
||||
// ruleIdMap is used to store the rule ID for a given connection
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
ruleIdMap sync.Map
|
||||
stack *stack.Stack
|
||||
endpoint *endpoint
|
||||
udpForwarder *udpForwarder
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
ip tcpip.Address
|
||||
netstack bool
|
||||
hasRawICMPAccess bool
|
||||
pingSemaphore chan struct{}
|
||||
}
|
||||
|
||||
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
|
||||
@@ -60,8 +63,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
endpoint := &endpoint{
|
||||
logger: logger,
|
||||
device: iface.GetWGDevice(),
|
||||
mtu: uint32(mtu),
|
||||
}
|
||||
endpoint.mtu.Store(uint32(mtu))
|
||||
|
||||
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||
return nil, fmt.Errorf("create NIC: %v", err)
|
||||
@@ -103,15 +106,16 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
f := &Forwarder{
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
logger: logger,
|
||||
flowLogger: flowLogger,
|
||||
stack: s,
|
||||
endpoint: endpoint,
|
||||
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
netstack: netstack,
|
||||
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||
pingSemaphore: make(chan struct{}, 3),
|
||||
}
|
||||
|
||||
receiveWindow := defaultReceiveWindow
|
||||
@@ -129,6 +133,8 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
||||
|
||||
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||
|
||||
f.checkICMPCapability()
|
||||
|
||||
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||
return f, nil
|
||||
}
|
||||
@@ -198,3 +204,24 @@ func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKe
|
||||
DstPort: dstPort,
|
||||
}
|
||||
}
|
||||
|
||||
// checkICMPCapability tests whether we have raw ICMP socket access at startup.
|
||||
func (f *Forwarder) checkICMPCapability() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.hasRawICMPAccess = false
|
||||
f.logger.Debug("forwarder: No raw ICMP socket access, will use ping binary fallback")
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP capability test socket: %v", err)
|
||||
}
|
||||
|
||||
f.hasRawICMPAccess = true
|
||||
f.logger.Debug("forwarder: Raw ICMP socket access available")
|
||||
}
|
||||
|
||||
@@ -2,8 +2,11 @@ package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -14,30 +17,95 @@ import (
|
||||
)
|
||||
|
||||
// handleICMP handles ICMP packets from the network stack
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
|
||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||
icmpType := uint8(icmpHdr.Type())
|
||||
icmpCode := uint8(icmpHdr.Code())
|
||||
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||
// dont process our own replies
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 0, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
// For Echo Requests, send and wait for response
|
||||
if icmpHdr.Type() == header.ICMPv4Echo {
|
||||
return f.handleICMPEcho(flowID, id, pkt, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc), forward without waiting
|
||||
if !f.hasRawICMPAccess {
|
||||
f.logger.Debug2("forwarder: Cannot handle ICMP type %v without raw socket access for %v", icmpHdr.Type(), epID(id))
|
||||
return false
|
||||
}
|
||||
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).AsSlice()
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, uint8(icmpHdr.Type()), uint8(icmpHdr.Code()), 100*time.Millisecond)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to forward ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// handleICMPEcho handles ICMP echo requests asynchronously with rate limiting.
|
||||
func (f *Forwarder) handleICMPEcho(flowID uuid.UUID, id stack.TransportEndpointID, pkt *stack.PacketBuffer, icmpType, icmpCode uint8) bool {
|
||||
select {
|
||||
case f.pingSemaphore <- struct{}{}:
|
||||
icmpData := stack.PayloadSince(pkt.TransportHeader()).ToSlice()
|
||||
rxBytes := pkt.Size()
|
||||
|
||||
go func() {
|
||||
defer func() { <-f.pingSemaphore }()
|
||||
|
||||
if f.hasRawICMPAccess {
|
||||
f.handleICMPViaSocket(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
} else {
|
||||
f.handleICMPViaPing(flowID, id, icmpType, icmpCode, icmpData, rxBytes)
|
||||
}
|
||||
}()
|
||||
default:
|
||||
f.logger.Debug3("forwarder: ICMP rate limit exceeded for %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// forwardICMPPacket creates a raw ICMP socket and sends the packet, returning the connection.
|
||||
// The caller is responsible for closing the returned connection.
|
||||
func (f *Forwarder) forwardICMPPacket(id stack.TransportEndpointID, payload []byte, icmpType, icmpCode uint8, timeout time.Duration) (net.PacketConn, error) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
lc := net.ListenConfig{}
|
||||
// TODO: support non-root
|
||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||
return nil, fmt.Errorf("create ICMP socket: %w", err)
|
||||
}
|
||||
|
||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||
return false
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
if closeErr := conn.Close(); closeErr != nil {
|
||||
f.logger.Debug1("forwarder: Failed to close ICMP socket: %v", closeErr)
|
||||
}
|
||||
return nil, fmt.Errorf("write ICMP packet: %w", err)
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// handleICMPViaSocket handles ICMP echo requests using raw sockets.
|
||||
func (f *Forwarder) handleICMPViaSocket(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
sendTime := time.Now()
|
||||
|
||||
conn, err := f.forwardICMPPacket(id, icmpData, icmpType, icmpCode, 5*time.Second)
|
||||
if err != nil {
|
||||
f.logger.Error2("forwarder: Failed to send ICMP packet for %v: %v", epID(id), err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
@@ -45,38 +113,22 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
||||
}
|
||||
}()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
dst := &net.IPAddr{IP: dstIP}
|
||||
txBytes := f.handleEchoResponse(conn, id)
|
||||
rtt := time.Since(sendTime).Round(10 * time.Microsecond)
|
||||
|
||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||
payload := fullPacket.AsSlice()
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, raw socket)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||
f.logger.Error2("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||
return true
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
// For Echo Requests, send and handle response
|
||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||
rxBytes := pkt.Size()
|
||||
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||
return true
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
response := make([]byte, f.endpoint.mtu)
|
||||
response := make([]byte, f.endpoint.mtu.Load())
|
||||
n, _, err := conn.ReadFrom(response)
|
||||
if err != nil {
|
||||
if !isTimeout(err) {
|
||||
@@ -85,31 +137,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
||||
return 0
|
||||
}
|
||||
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, response[:n]...)
|
||||
|
||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to inject ICMP response: %v", err)
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||
|
||||
return len(fullPacket)
|
||||
return f.injectICMPReply(id, response[:n])
|
||||
}
|
||||
|
||||
// sendICMPEvent stores flow events for ICMP packets
|
||||
@@ -152,3 +180,95 @@ func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.T
|
||||
|
||||
f.flowLogger.StoreEvent(fields)
|
||||
}
|
||||
|
||||
// handleICMPViaPing handles ICMP echo requests by executing the system ping binary.
|
||||
// This is used as a fallback when raw socket access is not available.
|
||||
func (f *Forwarder) handleICMPViaPing(flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, icmpData []byte, rxBytes int) {
|
||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||
cmd := buildPingCommand(ctx, dstIP, 5*time.Second)
|
||||
|
||||
pingStart := time.Now()
|
||||
if err := cmd.Run(); err != nil {
|
||||
f.logger.Warn4("forwarder: Ping binary failed for %v type %v code %v: %v", epID(id),
|
||||
icmpType, icmpCode, err)
|
||||
return
|
||||
}
|
||||
rtt := time.Since(pingStart).Round(10 * time.Microsecond)
|
||||
|
||||
f.logger.Trace3("forwarder: Forwarded ICMP echo request %v type %v code %v",
|
||||
epID(id), icmpType, icmpCode)
|
||||
|
||||
txBytes := f.synthesizeEchoReply(id, icmpData)
|
||||
|
||||
f.logger.Trace4("forwarder: Forwarded ICMP echo reply %v type %v code %v (rtt=%v, ping binary)",
|
||||
epID(id), icmpType, icmpCode, rtt)
|
||||
|
||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||
}
|
||||
|
||||
// buildPingCommand creates a platform-specific ping command.
|
||||
func buildPingCommand(ctx context.Context, target net.IP, timeout time.Duration) *exec.Cmd {
|
||||
timeoutSec := int(timeout.Seconds())
|
||||
if timeoutSec < 1 {
|
||||
timeoutSec = 1
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux", "android":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-W", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "darwin", "ios":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), "-q", target.String())
|
||||
case "freebsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-t", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "openbsd", "netbsd":
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", "-w", fmt.Sprintf("%d", timeoutSec), target.String())
|
||||
case "windows":
|
||||
return exec.CommandContext(ctx, "ping", "-n", "1", "-w", fmt.Sprintf("%d", timeoutSec*1000), target.String())
|
||||
default:
|
||||
return exec.CommandContext(ctx, "ping", "-c", "1", target.String())
|
||||
}
|
||||
}
|
||||
|
||||
// synthesizeEchoReply creates an ICMP echo reply from raw ICMP data and injects it back into the network stack.
|
||||
// Returns the size of the injected packet.
|
||||
func (f *Forwarder) synthesizeEchoReply(id stack.TransportEndpointID, icmpData []byte) int {
|
||||
replyICMP := make([]byte, len(icmpData))
|
||||
copy(replyICMP, icmpData)
|
||||
|
||||
replyICMPHdr := header.ICMPv4(replyICMP)
|
||||
replyICMPHdr.SetType(header.ICMPv4EchoReply)
|
||||
replyICMPHdr.SetChecksum(0)
|
||||
replyICMPHdr.SetChecksum(header.ICMPv4Checksum(replyICMPHdr, 0))
|
||||
|
||||
return f.injectICMPReply(id, replyICMP)
|
||||
}
|
||||
|
||||
// injectICMPReply wraps an ICMP payload in an IP header and injects it into the network stack.
|
||||
// Returns the total size of the injected packet, or 0 if injection failed.
|
||||
func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []byte) int {
|
||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||
ip := header.IPv4(ipHdr)
|
||||
ip.Encode(&header.IPv4Fields{
|
||||
TotalLength: uint16(header.IPv4MinimumSize + len(icmpPayload)),
|
||||
TTL: 64,
|
||||
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||
SrcAddr: id.LocalAddress,
|
||||
DstAddr: id.RemoteAddress,
|
||||
})
|
||||
ip.SetChecksum(^ip.CalculateChecksum())
|
||||
|
||||
fullPacket := make([]byte, 0, len(ipHdr)+len(icmpPayload))
|
||||
fullPacket = append(fullPacket, ipHdr...)
|
||||
fullPacket = append(fullPacket, icmpPayload...)
|
||||
|
||||
// Bypass netstack and send directly to peer to avoid looping through our ICMP handler
|
||||
if err := f.endpoint.device.CreateOutboundPacket(fullPacket, id.RemoteAddress.AsSlice()); err != nil {
|
||||
f.logger.Error1("forwarder: Failed to send ICMP reply to peer: %v", err)
|
||||
return 0
|
||||
}
|
||||
|
||||
return len(fullPacket)
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
@@ -131,10 +132,10 @@ func (f *udpForwarder) cleanup() {
|
||||
}
|
||||
|
||||
// handleUDP is called by the UDP forwarder for new packets
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) bool {
|
||||
if f.ctx.Err() != nil {
|
||||
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
id := r.ID()
|
||||
@@ -144,7 +145,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.udpForwarder.RUnlock()
|
||||
if exists {
|
||||
f.logger.Trace1("forwarder: existing UDP connection for %v", epID(id))
|
||||
return
|
||||
return true
|
||||
}
|
||||
|
||||
flowID := uuid.New()
|
||||
@@ -162,7 +163,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err != nil {
|
||||
f.logger.Debug2("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||
// TODO: Send ICMP error message
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
// Create wait queue for blocking syscalls
|
||||
@@ -173,10 +174,10 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||
inConn := gonet.NewUDPConn(&wq, ep)
|
||||
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||
|
||||
pConn := &udpPacketConn{
|
||||
@@ -199,7 +200,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
if err := outConn.Close(); err != nil {
|
||||
f.logger.Debug2("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||
}
|
||||
return
|
||||
return true
|
||||
}
|
||||
f.udpForwarder.conns[id] = pConn
|
||||
f.udpForwarder.Unlock()
|
||||
@@ -208,6 +209,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||
f.logger.Trace1("forwarder: established UDP connection %v", epID(id))
|
||||
|
||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||
return true
|
||||
}
|
||||
|
||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||
@@ -348,7 +350,7 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
||||
}
|
||||
|
||||
func isClosedError(err error) bool {
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) || errors.Is(err, io.EOF)
|
||||
}
|
||||
|
||||
func isTimeout(err error) bool {
|
||||
|
||||
@@ -168,6 +168,15 @@ func (l *Logger) Warn3(format string, arg1, arg2, arg3 any) {
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Warn4(format string, arg1, arg2, arg3, arg4 any) {
|
||||
if l.level.Load() >= uint32(LevelWarn) {
|
||||
select {
|
||||
case l.msgChannel <- logMessage{level: LevelWarn, format: format, arg1: arg1, arg2: arg2, arg3: arg3, arg4: arg4}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Debug1(format string, arg1 any) {
|
||||
if l.level.Load() >= uint32(LevelDebug) {
|
||||
select {
|
||||
|
||||
@@ -27,8 +27,23 @@ type receiverCreator struct {
|
||||
iceBind *ICEBind
|
||||
}
|
||||
|
||||
func (rc receiverCreator) CreateIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
return rc.iceBind.createIPv4ReceiverFn(pc, conn, rxOffload, msgPool)
|
||||
func (rc receiverCreator) CreateReceiverFn(pc wgConn.BatchReader, conn *net.UDPConn, rxOffload bool, msgPool *sync.Pool) wgConn.ReceiveFunc {
|
||||
if ipv4PC, ok := pc.(*ipv4.PacketConn); ok {
|
||||
return rc.iceBind.createIPv4ReceiverFn(ipv4PC, conn, rxOffload, msgPool)
|
||||
}
|
||||
// IPv6 is currently not supported in the udpmux, this is a stub for compatibility with the
|
||||
// wireguard-go ReceiverCreator interface which is called for both IPv4 and IPv6.
|
||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||
buf := bufs[0]
|
||||
size, ep, err := conn.ReadFromUDPAddrPort(buf)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sizes[0] = size
|
||||
stdEp := &wgConn.StdNetEndpoint{AddrPort: ep}
|
||||
eps[0] = stdEp
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ICEBind is a bind implementation with two main features:
|
||||
|
||||
Reference in New Issue
Block a user