diff --git a/Dockerfile b/Dockerfile index ad11376..197ac84 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt FROM alpine:3.23 AS runner -RUN apk --no-cache add ca-certificates tzdata +RUN apk --no-cache add ca-certificates tzdata iputils COPY --from=builder /newt /usr/local/bin/ COPY entrypoint.sh / diff --git a/clients/clients.go b/clients/clients.go index 7bc2669..7ef953f 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -40,12 +40,13 @@ type Target struct { SourcePrefix string `json:"sourcePrefix"` DestPrefix string `json:"destPrefix"` RewriteTo string `json:"rewriteTo,omitempty"` + DisableIcmp bool `json:"disableIcmp,omitempty"` PortRange []PortRange `json:"portRange,omitempty"` } type PortRange struct { - Min uint16 `json:"min"` - Max uint16 `json:"max"` + Min uint16 `json:"min"` + Max uint16 `json:"max"` Protocol string `json:"protocol"` // "tcp" or "udp" } @@ -593,8 +594,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.dns, s.mtu, netstack2.NetTunOptions{ - EnableTCPProxy: true, - EnableUDPProxy: true, + EnableTCPProxy: true, + EnableUDPProxy: true, + EnableICMPProxy: true, }, ) if err != nil { @@ -700,13 +702,13 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, + Min: pr.Min, + Max: pr.Max, Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } @@ -1094,10 +1096,11 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { portRanges = append(portRanges, netstack2.PortRange{ Min: pr.Min, Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } @@ -1209,12 +1212,13 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/netstack2/handlers.go b/netstack2/handlers.go index bdc9feb..014d872 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -10,12 +10,18 @@ import ( "fmt" "io" "net" + "net/netip" + "os/exec" "sync" "time" "github.com/fosrl/newt/logger" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -58,6 +64,9 @@ const ( // Buffer size for copying data bufferSize = 32 * 1024 + + // icmpTimeout is the default timeout for ICMP ping requests. + icmpTimeout = 5 * time.Second ) // TCPHandler handles TCP connections from netstack @@ -72,6 +81,12 @@ type UDPHandler struct { proxyHandler *ProxyHandler } +// ICMPHandler handles ICMP packets from netstack +type ICMPHandler struct { + stack *stack.Stack + proxyHandler *ProxyHandler +} + // NewTCPHandler creates a new TCP handler func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler { return &TCPHandler{stack: s, proxyHandler: ph} @@ -82,6 +97,11 @@ func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler { return &UDPHandler{stack: s, proxyHandler: ph} } +// NewICMPHandler creates a new ICMP handler +func NewICMPHandler(s *stack.Stack, ph *ProxyHandler) *ICMPHandler { + return &ICMPHandler{stack: s, proxyHandler: ph} +} + // InstallTCPHandler installs the TCP forwarder on the stack func (h *TCPHandler) InstallTCPHandler() error { tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { @@ -348,3 +368,334 @@ func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) dst.SetReadDeadline(time.Now().Add(timeout)) } } + +// InstallICMPHandler installs the ICMP handler on the stack +func (h *ICMPHandler) InstallICMPHandler() error { + h.stack.SetTransportProtocolHandler(header.ICMPv4ProtocolNumber, h.handleICMPPacket) + logger.Info("ICMP Handler: Installed ICMP protocol handler") + return nil +} + +// handleICMPPacket handles incoming ICMP packets +func (h *ICMPHandler) handleICMPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + logger.Debug("ICMP Handler: Received ICMP packet from %s to %s", id.RemoteAddress, id.LocalAddress) + + // Get the ICMP header from the packet + icmpData := pkt.TransportHeader().Slice() + if len(icmpData) < header.ICMPv4MinimumSize { + logger.Debug("ICMP Handler: Packet too small for ICMP header: %d bytes", len(icmpData)) + return false + } + + icmpHdr := header.ICMPv4(icmpData) + icmpType := icmpHdr.Type() + icmpCode := icmpHdr.Code() + + logger.Debug("ICMP Handler: Type=%d, Code=%d, Ident=%d, Seq=%d", + icmpType, icmpCode, icmpHdr.Ident(), icmpHdr.Sequence()) + + // Only handle Echo Request (ping) + if icmpType != header.ICMPv4Echo { + logger.Debug("ICMP Handler: Ignoring non-echo ICMP type: %d", icmpType) + return false + } + + // Extract source and destination addresses + srcIP := id.RemoteAddress.String() + dstIP := id.LocalAddress.String() + + logger.Info("ICMP Handler: Echo Request from %s to %s (ident=%d, seq=%d)", + srcIP, dstIP, icmpHdr.Ident(), icmpHdr.Sequence()) + + // Convert to netip.Addr for subnet matching + srcAddr, err := netip.ParseAddr(srcIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse source IP %s: %v", srcIP, err) + return false + } + dstAddr, err := netip.ParseAddr(dstIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse dest IP %s: %v", dstIP, err) + return false + } + + // Check subnet rules (use port 0 for ICMP since it doesn't have ports) + if h.proxyHandler == nil { + logger.Debug("ICMP Handler: No proxy handler configured") + return false + } + + matchedRule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, 0, header.ICMPv4ProtocolNumber) + if matchedRule == nil { + logger.Debug("ICMP Handler: No matching subnet rule for %s -> %s", srcIP, dstIP) + return false + } + + logger.Info("ICMP Handler: Matched subnet rule for %s -> %s", srcIP, dstIP) + + // Determine actual destination (with possible rewrite) + actualDstIP := dstIP + if matchedRule.RewriteTo != "" { + resolvedAddr, err := h.proxyHandler.resolveRewriteAddress(matchedRule.RewriteTo) + if err != nil { + logger.Info("ICMP Handler: Failed to resolve rewrite address %s: %v", matchedRule.RewriteTo, err) + } else { + actualDstIP = resolvedAddr.String() + logger.Info("ICMP Handler: Using rewritten destination %s (original: %s)", actualDstIP, dstIP) + } + } + + // Get the full ICMP payload (including the data after the header) + icmpPayload := pkt.Data().AsRange().ToSlice() + + // Handle the ping in a goroutine to avoid blocking + go h.proxyPing(srcIP, dstIP, actualDstIP, icmpHdr.Ident(), icmpHdr.Sequence(), icmpPayload) + + return true +} + +// proxyPing sends a ping to the actual destination and injects the reply back +func (h *ICMPHandler) proxyPing(srcIP, originalDstIP, actualDstIP string, ident, seq uint16, payload []byte) { + logger.Debug("ICMP Handler: Proxying ping from %s to %s (actual: %s), ident=%d, seq=%d", + srcIP, originalDstIP, actualDstIP, ident, seq) + + // Try three methods in order: ip4:icmp -> udp4 -> ping command + // Track which method succeeded so we can handle identifier matching correctly + method, success := h.tryICMPMethods(actualDstIP, ident, seq, payload) + + if !success { + logger.Info("ICMP Handler: All ping methods failed for %s", actualDstIP) + return + } + + logger.Info("ICMP Handler: Ping successful to %s using %s, injecting reply (ident=%d, seq=%d)", + actualDstIP, method, ident, seq) + + // Build the reply packet to inject back into the netstack + // The reply should appear to come from the original destination (before rewrite) + h.injectICMPReply(srcIP, originalDstIP, ident, seq, payload) +} + +// tryICMPMethods tries all available ICMP methods in order +func (h *ICMPHandler) tryICMPMethods(actualDstIP string, ident, seq uint16, payload []byte) (string, bool) { + if h.tryRawICMP(actualDstIP, ident, seq, payload, false) { + return "raw ICMP", true + } + if h.tryUnprivilegedICMP(actualDstIP, ident, seq, payload) { + return "unprivileged ICMP", true + } + if h.tryPingCommand(actualDstIP, ident, seq, payload) { + return "ping command", true + } + return "", false +} + +// tryRawICMP attempts to ping using raw ICMP sockets (requires CAP_NET_RAW or root) +func (h *ICMPHandler) tryRawICMP(actualDstIP string, ident, seq uint16, payload []byte, ignoreIdent bool) bool { + conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") + if err != nil { + logger.Debug("ICMP Handler: Raw ICMP socket not available: %v", err) + return false + } + defer conn.Close() + + logger.Debug("ICMP Handler: Using raw ICMP socket") + return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, false, ignoreIdent) +} + +// tryUnprivilegedICMP attempts to ping using unprivileged ICMP (requires ping_group_range configured) +func (h *ICMPHandler) tryUnprivilegedICMP(actualDstIP string, ident, seq uint16, payload []byte) bool { + conn, err := icmp.ListenPacket("udp4", "0.0.0.0") + if err != nil { + logger.Debug("ICMP Handler: Unprivileged ICMP socket not available: %v", err) + return false + } + defer conn.Close() + + logger.Debug("ICMP Handler: Using unprivileged ICMP socket") + // Unprivileged ICMP doesn't let us control the identifier, so we ignore it in matching + return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, true, true) +} + +// sendAndReceiveICMP sends an ICMP echo request and waits for the reply +func (h *ICMPHandler) sendAndReceiveICMP(conn *icmp.PacketConn, actualDstIP string, ident, seq uint16, payload []byte, isUnprivileged bool, ignoreIdent bool) bool { + // Build the ICMP echo request message + echoMsg := &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: int(ident), + Seq: int(seq), + Data: payload, + }, + } + + msgBytes, err := echoMsg.Marshal(nil) + if err != nil { + logger.Debug("ICMP Handler: Failed to marshal ICMP message: %v", err) + return false + } + + // Resolve destination address based on socket type + var writeErr error + if isUnprivileged { + // For unprivileged ICMP, use UDP-style addressing + udpAddr := &net.UDPAddr{IP: net.ParseIP(actualDstIP)} + logger.Debug("ICMP Handler: Sending ping to %s (unprivileged)", udpAddr.String()) + conn.SetDeadline(time.Now().Add(icmpTimeout)) + _, writeErr = conn.WriteTo(msgBytes, udpAddr) + } else { + // For raw ICMP, use IP addressing + dst, err := net.ResolveIPAddr("ip4", actualDstIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to resolve destination %s: %v", actualDstIP, err) + return false + } + logger.Debug("ICMP Handler: Sending ping to %s (raw)", dst.String()) + conn.SetDeadline(time.Now().Add(icmpTimeout)) + _, writeErr = conn.WriteTo(msgBytes, dst) + } + + if writeErr != nil { + logger.Debug("ICMP Handler: Failed to send ping to %s: %v", actualDstIP, writeErr) + return false + } + + logger.Debug("ICMP Handler: Ping sent to %s, waiting for reply (ident=%d, seq=%d)", actualDstIP, ident, seq) + + // Wait for reply - loop to filter out non-matching packets + replyBuf := make([]byte, 1500) + + for { + n, peer, err := conn.ReadFrom(replyBuf) + if err != nil { + logger.Debug("ICMP Handler: Failed to receive ping reply from %s: %v", actualDstIP, err) + return false + } + + logger.Debug("ICMP Handler: Received %d bytes from %s", n, peer.String()) + + // Parse the reply + replyMsg, err := icmp.ParseMessage(1, replyBuf[:n]) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse ICMP message: %v", err) + continue + } + + // Check if it's an echo reply (type 0), not an echo request (type 8) + if replyMsg.Type != ipv4.ICMPTypeEchoReply { + logger.Debug("ICMP Handler: Received non-echo-reply type: %v, continuing to wait", replyMsg.Type) + continue + } + + reply, ok := replyMsg.Body.(*icmp.Echo) + if !ok { + logger.Debug("ICMP Handler: Invalid echo reply body type, continuing to wait") + continue + } + + // Verify the sequence matches what we sent + // For unprivileged ICMP, the kernel controls the identifier, so we only check sequence + if reply.Seq != int(seq) { + logger.Debug("ICMP Handler: Reply seq mismatch: got seq=%d, want seq=%d", reply.Seq, seq) + continue + } + + if !ignoreIdent && reply.ID != int(ident) { + logger.Debug("ICMP Handler: Reply ident mismatch: got ident=%d, want ident=%d", reply.ID, ident) + continue + } + + // Found matching reply + logger.Debug("ICMP Handler: Received valid echo reply") + return true + } +} + +// tryPingCommand attempts to ping using the system ping command (always works, but less control) +func (h *ICMPHandler) tryPingCommand(actualDstIP string, ident, seq uint16, payload []byte) bool { + logger.Debug("ICMP Handler: Attempting to use system ping command") + + ctx, cancel := context.WithTimeout(context.Background(), icmpTimeout) + defer cancel() + + // Send one ping with timeout + // -c 1: count = 1 packet + // -W 5: timeout = 5 seconds + // -q: quiet output (just summary) + cmd := exec.CommandContext(ctx, "ping", "-c", "1", "-W", "5", "-q", actualDstIP) + output, err := cmd.CombinedOutput() + + if err != nil { + logger.Debug("ICMP Handler: System ping command failed: %v, output: %s", err, string(output)) + return false + } + + logger.Debug("ICMP Handler: System ping command succeeded") + return true +} + +// injectICMPReply creates an ICMP echo reply packet and queues it to be sent back through the tunnel +func (h *ICMPHandler) injectICMPReply(dstIP, srcIP string, ident, seq uint16, payload []byte) { + logger.Debug("ICMP Handler: Creating reply from %s to %s (ident=%d, seq=%d)", + srcIP, dstIP, ident, seq) + + // Parse addresses + srcAddr, err := netip.ParseAddr(srcIP) + if err != nil { + logger.Info("ICMP Handler: Failed to parse source IP for reply: %v", err) + return + } + dstAddr, err := netip.ParseAddr(dstIP) + if err != nil { + logger.Info("ICMP Handler: Failed to parse dest IP for reply: %v", err) + return + } + + // Calculate total packet size + ipHeaderLen := header.IPv4MinimumSize + icmpHeaderLen := header.ICMPv4MinimumSize + totalLen := ipHeaderLen + icmpHeaderLen + len(payload) + + // Create the packet buffer + pkt := make([]byte, totalLen) + + // Build IPv4 header + ipHdr := header.IPv4(pkt[:ipHeaderLen]) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: tcpip.AddrFrom4(srcAddr.As4()), + DstAddr: tcpip.AddrFrom4(dstAddr.As4()), + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + // Build ICMP header + icmpHdr := header.ICMPv4(pkt[ipHeaderLen : ipHeaderLen+icmpHeaderLen]) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetCode(0) + icmpHdr.SetIdent(ident) + icmpHdr.SetSequence(seq) + + // Copy payload + copy(pkt[ipHeaderLen+icmpHeaderLen:], payload) + + // Calculate ICMP checksum (covers ICMP header + payload) + icmpHdr.SetChecksum(0) + icmpData := pkt[ipHeaderLen:] + icmpHdr.SetChecksum(^checksum.Checksum(icmpData, 0)) + + logger.Debug("ICMP Handler: Built reply packet, total length=%d", totalLen) + + // Queue the packet to be sent back through the tunnel + if h.proxyHandler != nil { + if h.proxyHandler.QueueICMPReply(pkt) { + logger.Info("ICMP Handler: Queued echo reply packet for transmission") + } else { + logger.Info("ICMP Handler: Failed to queue echo reply packet") + } + } else { + logger.Info("ICMP Handler: Cannot queue reply - proxy handler not available") + } +} diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 3338cd0..fefb18d 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -43,6 +43,7 @@ type PortRange struct { type SubnetRule struct { SourcePrefix netip.Prefix // Source IP prefix (who is sending) DestPrefix netip.Prefix // Destination IP prefix (where it's going) + DisableIcmp bool // If true, ICMP traffic is blocked for this subnet RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name PortRanges []PortRange // empty slice means all ports allowed } @@ -69,7 +70,7 @@ func NewSubnetLookup() *SubnetLookup { // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { sl.mu.Lock() defer sl.mu.Unlock() @@ -81,6 +82,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite sl.rules[key] = &SubnetRule{ SourcePrefix: sourcePrefix, DestPrefix: destPrefix, + DisableIcmp: disableIcmp, RewriteTo: rewriteTo, PortRanges: portRanges, } @@ -123,6 +125,11 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip. continue } + if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) { + // ICMP is disabled for this subnet + return nil + } + // Both IPs match - now check port restrictions // If no port ranges specified, all ports are allowed if len(rule.PortRanges) == 0 { @@ -180,23 +187,27 @@ type ProxyHandler struct { proxyNotifyHandle *channel.NotificationHandle tcpHandler *TCPHandler udpHandler *UDPHandler + icmpHandler *ICMPHandler subnetLookup *SubnetLookup natTable map[connKey]*natState destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups natMu sync.RWMutex enabled bool + icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel + notifiable channel.Notification // Notification handler for triggering reads } // ProxyHandlerOptions configures the proxy handler type ProxyHandlerOptions struct { - EnableTCP bool - EnableUDP bool - MTU int + EnableTCP bool + EnableUDP bool + EnableICMP bool + MTU int } // NewProxyHandler creates a new proxy handler for promiscuous mode func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { - if !options.EnableTCP && !options.EnableUDP { + if !options.EnableTCP && !options.EnableUDP && !options.EnableICMP { return nil, nil // No proxy needed } @@ -205,6 +216,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { subnetLookup: NewSubnetLookup(), natTable: make(map[connKey]*natState), destRewriteTable: make(map[destKey]netip.Addr), + icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -236,6 +248,15 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { } } + // Initialize ICMP handler if enabled + if options.EnableICMP { + handler.icmpHandler = NewICMPHandler(handler.proxyStack, handler) + if err := handler.icmpHandler.InstallICMPHandler(); err != nil { + return nil, fmt.Errorf("failed to install ICMP handler: %v", err) + } + logger.Info("ProxyHandler: ICMP handler enabled") + } + // // Example 1: Add a rule with no port restrictions (all ports allowed) // // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24 // sourceSubnet := netip.MustParsePrefix("10.0.0.0/24") @@ -260,11 +281,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { // destPrefix: The IP prefix of the destination // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // If portRanges is nil or empty, all ports are allowed for this subnet -func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { if p == nil || !p.enabled { return } - p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges) + p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) } // RemoveSubnetRule removes a subnet from the proxy handler @@ -343,6 +364,9 @@ func (p *ProxyHandler) Initialize(notifiable channel.Notification) error { return nil } + // Store notifiable for triggering notifications on ICMP replies + p.notifiable = notifiable + // Add notification handler p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable) @@ -421,14 +445,21 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { } udpHeader := header.UDP(packet[headerLen:]) dstPort = udpHeader.DestinationPort() - default: - // For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions) + case header.ICMPv4ProtocolNumber: + // ICMP doesn't have ports, use port 0 (must match rules with no port restrictions) dstPort = 0 + logger.Debug("HandleIncomingPacket: ICMP packet from %s to %s", srcAddr, dstAddr) + default: + // For other protocols, use port 0 (must match rules with no port restrictions) + dstPort = 0 + logger.Debug("HandleIncomingPacket: Unknown protocol %d from %s to %s", protocol, srcAddr, dstAddr) } // Check if the source IP, destination IP, port, and protocol match any subnet rule matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { + logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)", + srcAddr, dstAddr, protocol, dstPort) // Check if we need to perform DNAT if matchedRule.RewriteTo != "" { // Create connection tracking key using original destination @@ -515,9 +546,12 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { Payload: buffer.MakeWithData(packet), }) p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + logger.Debug("HandleIncomingPacket: Injected packet into proxy stack (proto=%d)", protocol) return true } + logger.Debug("HandleIncomingPacket: No matching rule for %s -> %s (proto=%d, port=%d)", + srcAddr, dstAddr, protocol, dstPort) return false } @@ -640,6 +674,15 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { return nil } + // First check for ICMP reply packets (non-blocking) + select { + case icmpReply := <-p.icmpReplies: + logger.Debug("ReadOutgoingPacket: Returning ICMP reply packet (%d bytes)", len(icmpReply)) + return buffer.NewViewWithData(icmpReply) + default: + // No ICMP reply available, continue to check proxy endpoint + } + pkt := p.proxyEp.Read() if pkt != nil { view := pkt.ToView() @@ -669,6 +712,11 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { srcPort = udpHeader.SourcePort() dstPort = udpHeader.DestinationPort() } + case header.ICMPv4ProtocolNumber: + // ICMP packets don't need NAT translation in our implementation + // since we construct reply packets with the correct addresses + logger.Debug("ReadOutgoingPacket: ICMP packet from %s to %s", srcIP, dstIP) + return view } // Look up NAT state for reverse translation @@ -702,12 +750,37 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { return nil } +// QueueICMPReply queues an ICMP reply packet to be sent back through the tunnel +func (p *ProxyHandler) QueueICMPReply(packet []byte) bool { + if p == nil || !p.enabled { + return false + } + + select { + case p.icmpReplies <- packet: + logger.Debug("QueueICMPReply: Queued ICMP reply packet (%d bytes)", len(packet)) + // Trigger notification so WriteNotify picks up the packet + if p.notifiable != nil { + p.notifiable.WriteNotify() + } + return true + default: + logger.Info("QueueICMPReply: ICMP reply channel full, dropping packet") + return false + } +} + // Close cleans up the proxy handler resources func (p *ProxyHandler) Close() error { if p == nil || !p.enabled { return nil } + // Close ICMP replies channel + if p.icmpReplies != nil { + close(p.icmpReplies) + } + if p.proxyStack != nil { p.proxyStack.RemoveNIC(1) p.proxyStack.Close() diff --git a/netstack2/tun.go b/netstack2/tun.go index 4bcea65..e743f1e 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -56,15 +56,17 @@ type Net netTun // NetTunOptions contains options for creating a NetTUN device type NetTunOptions struct { - EnableTCPProxy bool - EnableUDPProxy bool + EnableTCPProxy bool + EnableUDPProxy bool + EnableICMPProxy bool } // CreateNetTUN creates a new TUN device with netstack without proxying func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{ - EnableTCPProxy: true, - EnableUDPProxy: true, + EnableTCPProxy: true, + EnableUDPProxy: true, + EnableICMPProxy: true, }) } @@ -84,13 +86,14 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o mtu: mtu, } - // Initialize proxy handler if TCP or UDP proxying is enabled - if options.EnableTCPProxy || options.EnableUDPProxy { + // Initialize proxy handler if TCP, UDP, or ICMP proxying is enabled + if options.EnableTCPProxy || options.EnableUDPProxy || options.EnableICMPProxy { var err error dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{ - EnableTCP: options.EnableTCPProxy, - EnableUDP: options.EnableUDPProxy, - MTU: mtu, + EnableTCP: options.EnableTCPProxy, + EnableUDP: options.EnableUDPProxy, + EnableICMP: options.EnableICMPProxy, + MTU: mtu, }) if err != nil { return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err) @@ -351,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { // AddProxySubnetRule adds a subnet rule to the proxy handler // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { tun := (*netTun)(net) if tun.proxyHandler != nil { - tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) + tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) } }