Merge pull request #208 from fosrl/icmp2

Support ICMP test requests for clients
This commit is contained in:
Owen Schwartz
2025-12-16 17:19:22 -05:00
committed by GitHub
5 changed files with 463 additions and 32 deletions

View File

@@ -20,7 +20,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt
FROM alpine:3.23 AS runner FROM alpine:3.23 AS runner
RUN apk --no-cache add ca-certificates tzdata RUN apk --no-cache add ca-certificates tzdata iputils
COPY --from=builder /newt /usr/local/bin/ COPY --from=builder /newt /usr/local/bin/
COPY entrypoint.sh / COPY entrypoint.sh /

View File

@@ -40,12 +40,13 @@ type Target struct {
SourcePrefix string `json:"sourcePrefix"` SourcePrefix string `json:"sourcePrefix"`
DestPrefix string `json:"destPrefix"` DestPrefix string `json:"destPrefix"`
RewriteTo string `json:"rewriteTo,omitempty"` RewriteTo string `json:"rewriteTo,omitempty"`
DisableIcmp bool `json:"disableIcmp,omitempty"`
PortRange []PortRange `json:"portRange,omitempty"` PortRange []PortRange `json:"portRange,omitempty"`
} }
type PortRange struct { type PortRange struct {
Min uint16 `json:"min"` Min uint16 `json:"min"`
Max uint16 `json:"max"` Max uint16 `json:"max"`
Protocol string `json:"protocol"` // "tcp" or "udp" Protocol string `json:"protocol"` // "tcp" or "udp"
} }
@@ -593,8 +594,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error {
s.dns, s.dns,
s.mtu, s.mtu,
netstack2.NetTunOptions{ netstack2.NetTunOptions{
EnableTCPProxy: true, EnableTCPProxy: true,
EnableUDPProxy: true, EnableUDPProxy: true,
EnableICMPProxy: true,
}, },
) )
if err != nil { if err != nil {
@@ -700,13 +702,13 @@ func (s *WireGuardService) ensureTargets(targets []Target) error {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol, Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
} }
@@ -1094,10 +1096,11 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
} }
@@ -1209,12 +1212,13 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) {
var portRanges []netstack2.PortRange var portRanges []netstack2.PortRange
for _, pr := range target.PortRange { for _, pr := range target.PortRange {
portRanges = append(portRanges, netstack2.PortRange{ portRanges = append(portRanges, netstack2.PortRange{
Min: pr.Min, Min: pr.Min,
Max: pr.Max, Max: pr.Max,
Protocol: pr.Protocol,
}) })
} }
s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp)
logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange)
} }
} }

View File

@@ -10,12 +10,18 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"net/netip"
"os/exec"
"sync" "sync"
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/checksum"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -58,6 +64,9 @@ const (
// Buffer size for copying data // Buffer size for copying data
bufferSize = 32 * 1024 bufferSize = 32 * 1024
// icmpTimeout is the default timeout for ICMP ping requests.
icmpTimeout = 5 * time.Second
) )
// TCPHandler handles TCP connections from netstack // TCPHandler handles TCP connections from netstack
@@ -72,6 +81,12 @@ type UDPHandler struct {
proxyHandler *ProxyHandler proxyHandler *ProxyHandler
} }
// ICMPHandler handles ICMP packets from netstack
type ICMPHandler struct {
stack *stack.Stack
proxyHandler *ProxyHandler
}
// NewTCPHandler creates a new TCP handler // NewTCPHandler creates a new TCP handler
func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler { func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler {
return &TCPHandler{stack: s, proxyHandler: ph} return &TCPHandler{stack: s, proxyHandler: ph}
@@ -82,6 +97,11 @@ func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler {
return &UDPHandler{stack: s, proxyHandler: ph} return &UDPHandler{stack: s, proxyHandler: ph}
} }
// NewICMPHandler creates a new ICMP handler
func NewICMPHandler(s *stack.Stack, ph *ProxyHandler) *ICMPHandler {
return &ICMPHandler{stack: s, proxyHandler: ph}
}
// InstallTCPHandler installs the TCP forwarder on the stack // InstallTCPHandler installs the TCP forwarder on the stack
func (h *TCPHandler) InstallTCPHandler() error { func (h *TCPHandler) InstallTCPHandler() error {
tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) {
@@ -348,3 +368,334 @@ func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration)
dst.SetReadDeadline(time.Now().Add(timeout)) dst.SetReadDeadline(time.Now().Add(timeout))
} }
} }
// InstallICMPHandler installs the ICMP handler on the stack
func (h *ICMPHandler) InstallICMPHandler() error {
h.stack.SetTransportProtocolHandler(header.ICMPv4ProtocolNumber, h.handleICMPPacket)
logger.Info("ICMP Handler: Installed ICMP protocol handler")
return nil
}
// handleICMPPacket handles incoming ICMP packets
func (h *ICMPHandler) handleICMPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
logger.Debug("ICMP Handler: Received ICMP packet from %s to %s", id.RemoteAddress, id.LocalAddress)
// Get the ICMP header from the packet
icmpData := pkt.TransportHeader().Slice()
if len(icmpData) < header.ICMPv4MinimumSize {
logger.Debug("ICMP Handler: Packet too small for ICMP header: %d bytes", len(icmpData))
return false
}
icmpHdr := header.ICMPv4(icmpData)
icmpType := icmpHdr.Type()
icmpCode := icmpHdr.Code()
logger.Debug("ICMP Handler: Type=%d, Code=%d, Ident=%d, Seq=%d",
icmpType, icmpCode, icmpHdr.Ident(), icmpHdr.Sequence())
// Only handle Echo Request (ping)
if icmpType != header.ICMPv4Echo {
logger.Debug("ICMP Handler: Ignoring non-echo ICMP type: %d", icmpType)
return false
}
// Extract source and destination addresses
srcIP := id.RemoteAddress.String()
dstIP := id.LocalAddress.String()
logger.Info("ICMP Handler: Echo Request from %s to %s (ident=%d, seq=%d)",
srcIP, dstIP, icmpHdr.Ident(), icmpHdr.Sequence())
// Convert to netip.Addr for subnet matching
srcAddr, err := netip.ParseAddr(srcIP)
if err != nil {
logger.Debug("ICMP Handler: Failed to parse source IP %s: %v", srcIP, err)
return false
}
dstAddr, err := netip.ParseAddr(dstIP)
if err != nil {
logger.Debug("ICMP Handler: Failed to parse dest IP %s: %v", dstIP, err)
return false
}
// Check subnet rules (use port 0 for ICMP since it doesn't have ports)
if h.proxyHandler == nil {
logger.Debug("ICMP Handler: No proxy handler configured")
return false
}
matchedRule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, 0, header.ICMPv4ProtocolNumber)
if matchedRule == nil {
logger.Debug("ICMP Handler: No matching subnet rule for %s -> %s", srcIP, dstIP)
return false
}
logger.Info("ICMP Handler: Matched subnet rule for %s -> %s", srcIP, dstIP)
// Determine actual destination (with possible rewrite)
actualDstIP := dstIP
if matchedRule.RewriteTo != "" {
resolvedAddr, err := h.proxyHandler.resolveRewriteAddress(matchedRule.RewriteTo)
if err != nil {
logger.Info("ICMP Handler: Failed to resolve rewrite address %s: %v", matchedRule.RewriteTo, err)
} else {
actualDstIP = resolvedAddr.String()
logger.Info("ICMP Handler: Using rewritten destination %s (original: %s)", actualDstIP, dstIP)
}
}
// Get the full ICMP payload (including the data after the header)
icmpPayload := pkt.Data().AsRange().ToSlice()
// Handle the ping in a goroutine to avoid blocking
go h.proxyPing(srcIP, dstIP, actualDstIP, icmpHdr.Ident(), icmpHdr.Sequence(), icmpPayload)
return true
}
// proxyPing sends a ping to the actual destination and injects the reply back
func (h *ICMPHandler) proxyPing(srcIP, originalDstIP, actualDstIP string, ident, seq uint16, payload []byte) {
logger.Debug("ICMP Handler: Proxying ping from %s to %s (actual: %s), ident=%d, seq=%d",
srcIP, originalDstIP, actualDstIP, ident, seq)
// Try three methods in order: ip4:icmp -> udp4 -> ping command
// Track which method succeeded so we can handle identifier matching correctly
method, success := h.tryICMPMethods(actualDstIP, ident, seq, payload)
if !success {
logger.Info("ICMP Handler: All ping methods failed for %s", actualDstIP)
return
}
logger.Info("ICMP Handler: Ping successful to %s using %s, injecting reply (ident=%d, seq=%d)",
actualDstIP, method, ident, seq)
// Build the reply packet to inject back into the netstack
// The reply should appear to come from the original destination (before rewrite)
h.injectICMPReply(srcIP, originalDstIP, ident, seq, payload)
}
// tryICMPMethods tries all available ICMP methods in order
func (h *ICMPHandler) tryICMPMethods(actualDstIP string, ident, seq uint16, payload []byte) (string, bool) {
if h.tryRawICMP(actualDstIP, ident, seq, payload, false) {
return "raw ICMP", true
}
if h.tryUnprivilegedICMP(actualDstIP, ident, seq, payload) {
return "unprivileged ICMP", true
}
if h.tryPingCommand(actualDstIP, ident, seq, payload) {
return "ping command", true
}
return "", false
}
// tryRawICMP attempts to ping using raw ICMP sockets (requires CAP_NET_RAW or root)
func (h *ICMPHandler) tryRawICMP(actualDstIP string, ident, seq uint16, payload []byte, ignoreIdent bool) bool {
conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0")
if err != nil {
logger.Debug("ICMP Handler: Raw ICMP socket not available: %v", err)
return false
}
defer conn.Close()
logger.Debug("ICMP Handler: Using raw ICMP socket")
return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, false, ignoreIdent)
}
// tryUnprivilegedICMP attempts to ping using unprivileged ICMP (requires ping_group_range configured)
func (h *ICMPHandler) tryUnprivilegedICMP(actualDstIP string, ident, seq uint16, payload []byte) bool {
conn, err := icmp.ListenPacket("udp4", "0.0.0.0")
if err != nil {
logger.Debug("ICMP Handler: Unprivileged ICMP socket not available: %v", err)
return false
}
defer conn.Close()
logger.Debug("ICMP Handler: Using unprivileged ICMP socket")
// Unprivileged ICMP doesn't let us control the identifier, so we ignore it in matching
return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, true, true)
}
// sendAndReceiveICMP sends an ICMP echo request and waits for the reply
func (h *ICMPHandler) sendAndReceiveICMP(conn *icmp.PacketConn, actualDstIP string, ident, seq uint16, payload []byte, isUnprivileged bool, ignoreIdent bool) bool {
// Build the ICMP echo request message
echoMsg := &icmp.Message{
Type: ipv4.ICMPTypeEcho,
Code: 0,
Body: &icmp.Echo{
ID: int(ident),
Seq: int(seq),
Data: payload,
},
}
msgBytes, err := echoMsg.Marshal(nil)
if err != nil {
logger.Debug("ICMP Handler: Failed to marshal ICMP message: %v", err)
return false
}
// Resolve destination address based on socket type
var writeErr error
if isUnprivileged {
// For unprivileged ICMP, use UDP-style addressing
udpAddr := &net.UDPAddr{IP: net.ParseIP(actualDstIP)}
logger.Debug("ICMP Handler: Sending ping to %s (unprivileged)", udpAddr.String())
conn.SetDeadline(time.Now().Add(icmpTimeout))
_, writeErr = conn.WriteTo(msgBytes, udpAddr)
} else {
// For raw ICMP, use IP addressing
dst, err := net.ResolveIPAddr("ip4", actualDstIP)
if err != nil {
logger.Debug("ICMP Handler: Failed to resolve destination %s: %v", actualDstIP, err)
return false
}
logger.Debug("ICMP Handler: Sending ping to %s (raw)", dst.String())
conn.SetDeadline(time.Now().Add(icmpTimeout))
_, writeErr = conn.WriteTo(msgBytes, dst)
}
if writeErr != nil {
logger.Debug("ICMP Handler: Failed to send ping to %s: %v", actualDstIP, writeErr)
return false
}
logger.Debug("ICMP Handler: Ping sent to %s, waiting for reply (ident=%d, seq=%d)", actualDstIP, ident, seq)
// Wait for reply - loop to filter out non-matching packets
replyBuf := make([]byte, 1500)
for {
n, peer, err := conn.ReadFrom(replyBuf)
if err != nil {
logger.Debug("ICMP Handler: Failed to receive ping reply from %s: %v", actualDstIP, err)
return false
}
logger.Debug("ICMP Handler: Received %d bytes from %s", n, peer.String())
// Parse the reply
replyMsg, err := icmp.ParseMessage(1, replyBuf[:n])
if err != nil {
logger.Debug("ICMP Handler: Failed to parse ICMP message: %v", err)
continue
}
// Check if it's an echo reply (type 0), not an echo request (type 8)
if replyMsg.Type != ipv4.ICMPTypeEchoReply {
logger.Debug("ICMP Handler: Received non-echo-reply type: %v, continuing to wait", replyMsg.Type)
continue
}
reply, ok := replyMsg.Body.(*icmp.Echo)
if !ok {
logger.Debug("ICMP Handler: Invalid echo reply body type, continuing to wait")
continue
}
// Verify the sequence matches what we sent
// For unprivileged ICMP, the kernel controls the identifier, so we only check sequence
if reply.Seq != int(seq) {
logger.Debug("ICMP Handler: Reply seq mismatch: got seq=%d, want seq=%d", reply.Seq, seq)
continue
}
if !ignoreIdent && reply.ID != int(ident) {
logger.Debug("ICMP Handler: Reply ident mismatch: got ident=%d, want ident=%d", reply.ID, ident)
continue
}
// Found matching reply
logger.Debug("ICMP Handler: Received valid echo reply")
return true
}
}
// tryPingCommand attempts to ping using the system ping command (always works, but less control)
func (h *ICMPHandler) tryPingCommand(actualDstIP string, ident, seq uint16, payload []byte) bool {
logger.Debug("ICMP Handler: Attempting to use system ping command")
ctx, cancel := context.WithTimeout(context.Background(), icmpTimeout)
defer cancel()
// Send one ping with timeout
// -c 1: count = 1 packet
// -W 5: timeout = 5 seconds
// -q: quiet output (just summary)
cmd := exec.CommandContext(ctx, "ping", "-c", "1", "-W", "5", "-q", actualDstIP)
output, err := cmd.CombinedOutput()
if err != nil {
logger.Debug("ICMP Handler: System ping command failed: %v, output: %s", err, string(output))
return false
}
logger.Debug("ICMP Handler: System ping command succeeded")
return true
}
// injectICMPReply creates an ICMP echo reply packet and queues it to be sent back through the tunnel
func (h *ICMPHandler) injectICMPReply(dstIP, srcIP string, ident, seq uint16, payload []byte) {
logger.Debug("ICMP Handler: Creating reply from %s to %s (ident=%d, seq=%d)",
srcIP, dstIP, ident, seq)
// Parse addresses
srcAddr, err := netip.ParseAddr(srcIP)
if err != nil {
logger.Info("ICMP Handler: Failed to parse source IP for reply: %v", err)
return
}
dstAddr, err := netip.ParseAddr(dstIP)
if err != nil {
logger.Info("ICMP Handler: Failed to parse dest IP for reply: %v", err)
return
}
// Calculate total packet size
ipHeaderLen := header.IPv4MinimumSize
icmpHeaderLen := header.ICMPv4MinimumSize
totalLen := ipHeaderLen + icmpHeaderLen + len(payload)
// Create the packet buffer
pkt := make([]byte, totalLen)
// Build IPv4 header
ipHdr := header.IPv4(pkt[:ipHeaderLen])
ipHdr.Encode(&header.IPv4Fields{
TotalLength: uint16(totalLen),
TTL: 64,
Protocol: uint8(header.ICMPv4ProtocolNumber),
SrcAddr: tcpip.AddrFrom4(srcAddr.As4()),
DstAddr: tcpip.AddrFrom4(dstAddr.As4()),
})
ipHdr.SetChecksum(^ipHdr.CalculateChecksum())
// Build ICMP header
icmpHdr := header.ICMPv4(pkt[ipHeaderLen : ipHeaderLen+icmpHeaderLen])
icmpHdr.SetType(header.ICMPv4EchoReply)
icmpHdr.SetCode(0)
icmpHdr.SetIdent(ident)
icmpHdr.SetSequence(seq)
// Copy payload
copy(pkt[ipHeaderLen+icmpHeaderLen:], payload)
// Calculate ICMP checksum (covers ICMP header + payload)
icmpHdr.SetChecksum(0)
icmpData := pkt[ipHeaderLen:]
icmpHdr.SetChecksum(^checksum.Checksum(icmpData, 0))
logger.Debug("ICMP Handler: Built reply packet, total length=%d", totalLen)
// Queue the packet to be sent back through the tunnel
if h.proxyHandler != nil {
if h.proxyHandler.QueueICMPReply(pkt) {
logger.Info("ICMP Handler: Queued echo reply packet for transmission")
} else {
logger.Info("ICMP Handler: Failed to queue echo reply packet")
}
} else {
logger.Info("ICMP Handler: Cannot queue reply - proxy handler not available")
}
}

View File

@@ -43,6 +43,7 @@ type PortRange struct {
type SubnetRule struct { type SubnetRule struct {
SourcePrefix netip.Prefix // Source IP prefix (who is sending) SourcePrefix netip.Prefix // Source IP prefix (who is sending)
DestPrefix netip.Prefix // Destination IP prefix (where it's going) DestPrefix netip.Prefix // Destination IP prefix (where it's going)
DisableIcmp bool // If true, ICMP traffic is blocked for this subnet
RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name
PortRanges []PortRange // empty slice means all ports allowed PortRanges []PortRange // empty slice means all ports allowed
} }
@@ -69,7 +70,7 @@ func NewSubnetLookup() *SubnetLookup {
// AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
sl.mu.Lock() sl.mu.Lock()
defer sl.mu.Unlock() defer sl.mu.Unlock()
@@ -81,6 +82,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite
sl.rules[key] = &SubnetRule{ sl.rules[key] = &SubnetRule{
SourcePrefix: sourcePrefix, SourcePrefix: sourcePrefix,
DestPrefix: destPrefix, DestPrefix: destPrefix,
DisableIcmp: disableIcmp,
RewriteTo: rewriteTo, RewriteTo: rewriteTo,
PortRanges: portRanges, PortRanges: portRanges,
} }
@@ -123,6 +125,11 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.
continue continue
} }
if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) {
// ICMP is disabled for this subnet
return nil
}
// Both IPs match - now check port restrictions // Both IPs match - now check port restrictions
// If no port ranges specified, all ports are allowed // If no port ranges specified, all ports are allowed
if len(rule.PortRanges) == 0 { if len(rule.PortRanges) == 0 {
@@ -180,23 +187,27 @@ type ProxyHandler struct {
proxyNotifyHandle *channel.NotificationHandle proxyNotifyHandle *channel.NotificationHandle
tcpHandler *TCPHandler tcpHandler *TCPHandler
udpHandler *UDPHandler udpHandler *UDPHandler
icmpHandler *ICMPHandler
subnetLookup *SubnetLookup subnetLookup *SubnetLookup
natTable map[connKey]*natState natTable map[connKey]*natState
destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups
natMu sync.RWMutex natMu sync.RWMutex
enabled bool enabled bool
icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel
notifiable channel.Notification // Notification handler for triggering reads
} }
// ProxyHandlerOptions configures the proxy handler // ProxyHandlerOptions configures the proxy handler
type ProxyHandlerOptions struct { type ProxyHandlerOptions struct {
EnableTCP bool EnableTCP bool
EnableUDP bool EnableUDP bool
MTU int EnableICMP bool
MTU int
} }
// NewProxyHandler creates a new proxy handler for promiscuous mode // NewProxyHandler creates a new proxy handler for promiscuous mode
func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
if !options.EnableTCP && !options.EnableUDP { if !options.EnableTCP && !options.EnableUDP && !options.EnableICMP {
return nil, nil // No proxy needed return nil, nil // No proxy needed
} }
@@ -205,6 +216,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
subnetLookup: NewSubnetLookup(), subnetLookup: NewSubnetLookup(),
natTable: make(map[connKey]*natState), natTable: make(map[connKey]*natState),
destRewriteTable: make(map[destKey]netip.Addr), destRewriteTable: make(map[destKey]netip.Addr),
icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets
proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyEp: channel.New(1024, uint32(options.MTU), ""),
proxyStack: stack.New(stack.Options{ proxyStack: stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ NetworkProtocols: []stack.NetworkProtocolFactory{
@@ -236,6 +248,15 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
} }
} }
// Initialize ICMP handler if enabled
if options.EnableICMP {
handler.icmpHandler = NewICMPHandler(handler.proxyStack, handler)
if err := handler.icmpHandler.InstallICMPHandler(); err != nil {
return nil, fmt.Errorf("failed to install ICMP handler: %v", err)
}
logger.Info("ProxyHandler: ICMP handler enabled")
}
// // Example 1: Add a rule with no port restrictions (all ports allowed) // // Example 1: Add a rule with no port restrictions (all ports allowed)
// // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24 // // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24
// sourceSubnet := netip.MustParsePrefix("10.0.0.0/24") // sourceSubnet := netip.MustParsePrefix("10.0.0.0/24")
@@ -260,11 +281,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) {
// destPrefix: The IP prefix of the destination // destPrefix: The IP prefix of the destination
// rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
if p == nil || !p.enabled { if p == nil || !p.enabled {
return return
} }
p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges) p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
} }
// RemoveSubnetRule removes a subnet from the proxy handler // RemoveSubnetRule removes a subnet from the proxy handler
@@ -343,6 +364,9 @@ func (p *ProxyHandler) Initialize(notifiable channel.Notification) error {
return nil return nil
} }
// Store notifiable for triggering notifications on ICMP replies
p.notifiable = notifiable
// Add notification handler // Add notification handler
p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable) p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable)
@@ -421,14 +445,21 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
} }
udpHeader := header.UDP(packet[headerLen:]) udpHeader := header.UDP(packet[headerLen:])
dstPort = udpHeader.DestinationPort() dstPort = udpHeader.DestinationPort()
default: case header.ICMPv4ProtocolNumber:
// For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions) // ICMP doesn't have ports, use port 0 (must match rules with no port restrictions)
dstPort = 0 dstPort = 0
logger.Debug("HandleIncomingPacket: ICMP packet from %s to %s", srcAddr, dstAddr)
default:
// For other protocols, use port 0 (must match rules with no port restrictions)
dstPort = 0
logger.Debug("HandleIncomingPacket: Unknown protocol %d from %s to %s", protocol, srcAddr, dstAddr)
} }
// Check if the source IP, destination IP, port, and protocol match any subnet rule // Check if the source IP, destination IP, port, and protocol match any subnet rule
matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol)
if matchedRule != nil { if matchedRule != nil {
logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
// Check if we need to perform DNAT // Check if we need to perform DNAT
if matchedRule.RewriteTo != "" { if matchedRule.RewriteTo != "" {
// Create connection tracking key using original destination // Create connection tracking key using original destination
@@ -515,9 +546,12 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool {
Payload: buffer.MakeWithData(packet), Payload: buffer.MakeWithData(packet),
}) })
p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb)
logger.Debug("HandleIncomingPacket: Injected packet into proxy stack (proto=%d)", protocol)
return true return true
} }
logger.Debug("HandleIncomingPacket: No matching rule for %s -> %s (proto=%d, port=%d)",
srcAddr, dstAddr, protocol, dstPort)
return false return false
} }
@@ -640,6 +674,15 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
return nil return nil
} }
// First check for ICMP reply packets (non-blocking)
select {
case icmpReply := <-p.icmpReplies:
logger.Debug("ReadOutgoingPacket: Returning ICMP reply packet (%d bytes)", len(icmpReply))
return buffer.NewViewWithData(icmpReply)
default:
// No ICMP reply available, continue to check proxy endpoint
}
pkt := p.proxyEp.Read() pkt := p.proxyEp.Read()
if pkt != nil { if pkt != nil {
view := pkt.ToView() view := pkt.ToView()
@@ -669,6 +712,11 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
srcPort = udpHeader.SourcePort() srcPort = udpHeader.SourcePort()
dstPort = udpHeader.DestinationPort() dstPort = udpHeader.DestinationPort()
} }
case header.ICMPv4ProtocolNumber:
// ICMP packets don't need NAT translation in our implementation
// since we construct reply packets with the correct addresses
logger.Debug("ReadOutgoingPacket: ICMP packet from %s to %s", srcIP, dstIP)
return view
} }
// Look up NAT state for reverse translation // Look up NAT state for reverse translation
@@ -702,12 +750,37 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View {
return nil return nil
} }
// QueueICMPReply queues an ICMP reply packet to be sent back through the tunnel
func (p *ProxyHandler) QueueICMPReply(packet []byte) bool {
if p == nil || !p.enabled {
return false
}
select {
case p.icmpReplies <- packet:
logger.Debug("QueueICMPReply: Queued ICMP reply packet (%d bytes)", len(packet))
// Trigger notification so WriteNotify picks up the packet
if p.notifiable != nil {
p.notifiable.WriteNotify()
}
return true
default:
logger.Info("QueueICMPReply: ICMP reply channel full, dropping packet")
return false
}
}
// Close cleans up the proxy handler resources // Close cleans up the proxy handler resources
func (p *ProxyHandler) Close() error { func (p *ProxyHandler) Close() error {
if p == nil || !p.enabled { if p == nil || !p.enabled {
return nil return nil
} }
// Close ICMP replies channel
if p.icmpReplies != nil {
close(p.icmpReplies)
}
if p.proxyStack != nil { if p.proxyStack != nil {
p.proxyStack.RemoveNIC(1) p.proxyStack.RemoveNIC(1)
p.proxyStack.Close() p.proxyStack.Close()

View File

@@ -56,15 +56,17 @@ type Net netTun
// NetTunOptions contains options for creating a NetTUN device // NetTunOptions contains options for creating a NetTUN device
type NetTunOptions struct { type NetTunOptions struct {
EnableTCPProxy bool EnableTCPProxy bool
EnableUDPProxy bool EnableUDPProxy bool
EnableICMPProxy bool
} }
// CreateNetTUN creates a new TUN device with netstack without proxying // CreateNetTUN creates a new TUN device with netstack without proxying
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{ return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{
EnableTCPProxy: true, EnableTCPProxy: true,
EnableUDPProxy: true, EnableUDPProxy: true,
EnableICMPProxy: true,
}) })
} }
@@ -84,13 +86,14 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o
mtu: mtu, mtu: mtu,
} }
// Initialize proxy handler if TCP or UDP proxying is enabled // Initialize proxy handler if TCP, UDP, or ICMP proxying is enabled
if options.EnableTCPProxy || options.EnableUDPProxy { if options.EnableTCPProxy || options.EnableUDPProxy || options.EnableICMPProxy {
var err error var err error
dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{ dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{
EnableTCP: options.EnableTCPProxy, EnableTCP: options.EnableTCPProxy,
EnableUDP: options.EnableUDPProxy, EnableUDP: options.EnableUDPProxy,
MTU: mtu, EnableICMP: options.EnableICMPProxy,
MTU: mtu,
}) })
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err) return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err)
@@ -351,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) {
// AddProxySubnetRule adds a subnet rule to the proxy handler // AddProxySubnetRule adds a subnet rule to the proxy handler
// If portRanges is nil or empty, all ports are allowed for this subnet // If portRanges is nil or empty, all ports are allowed for this subnet
// rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com")
func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) {
tun := (*netTun)(net) tun := (*netTun)(net)
if tun.proxyHandler != nil { if tun.proxyHandler != nil {
tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp)
} }
} }