diff --git a/Dockerfile b/Dockerfile index ad11376..197ac84 100644 --- a/Dockerfile +++ b/Dockerfile @@ -20,7 +20,7 @@ RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w" -o /newt FROM alpine:3.23 AS runner -RUN apk --no-cache add ca-certificates tzdata +RUN apk --no-cache add ca-certificates tzdata iputils COPY --from=builder /newt /usr/local/bin/ COPY entrypoint.sh / diff --git a/clients/clients.go b/clients/clients.go index 17a5398..b2dca47 100644 --- a/clients/clients.go +++ b/clients/clients.go @@ -40,12 +40,14 @@ type Target struct { SourcePrefix string `json:"sourcePrefix"` DestPrefix string `json:"destPrefix"` RewriteTo string `json:"rewriteTo,omitempty"` + DisableIcmp bool `json:"disableIcmp,omitempty"` PortRange []PortRange `json:"portRange,omitempty"` } type PortRange struct { - Min uint16 `json:"min"` - Max uint16 `json:"max"` + Min uint16 `json:"min"` + Max uint16 `json:"max"` + Protocol string `json:"protocol"` // "tcp" or "udp" } type Peer struct { @@ -475,6 +477,8 @@ func (s *WireGuardService) handleConfig(msg websocket.WSMessage) { // Ensure the WireGuard interface and peers are configured if err := s.ensureWireguardInterface(config); err != nil { logger.Error("Failed to ensure WireGuard interface: %v", err) + logger.Error("Clients functionality will be disabled until the interface can be created") + return } if err := s.ensureWireguardPeers(config.Peers); err != nil { @@ -597,8 +601,9 @@ func (s *WireGuardService) ensureWireguardInterface(wgconfig WgConfig) error { s.dns, s.mtu, netstack2.NetTunOptions{ - EnableTCPProxy: true, - EnableUDPProxy: true, + EnableTCPProxy: true, + EnableUDPProxy: true, + EnableICMPProxy: true, }, ) if err != nil { @@ -649,6 +654,11 @@ func (s *WireGuardService) ensureWireguardPeers(peers []Peer) error { // For netstack, we need to manage peers differently // We'll configure peers directly on the device using IPC + // Check if device is initialized + if s.device == nil { + return fmt.Errorf("WireGuard device is not initialized") + } + // First, clear all existing peers by getting current config and removing them currentConfig, err := s.device.IpcGet() if err != nil { @@ -704,12 +714,13 @@ func (s *WireGuardService) ensureTargets(targets []Target) error { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } @@ -1097,10 +1108,11 @@ func (s *WireGuardService) handleAddTarget(msg websocket.WSMessage) { portRanges = append(portRanges, netstack2.PortRange{ Min: pr.Min, Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } @@ -1212,12 +1224,13 @@ func (s *WireGuardService) handleUpdateTarget(msg websocket.WSMessage) { var portRanges []netstack2.PortRange for _, pr := range target.PortRange { portRanges = append(portRanges, netstack2.PortRange{ - Min: pr.Min, - Max: pr.Max, + Min: pr.Min, + Max: pr.Max, + Protocol: pr.Protocol, }) } - s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges) + s.tnet.AddProxySubnetRule(sourcePrefix, destPrefix, target.RewriteTo, portRanges, target.DisableIcmp) logger.Info("Added target subnet from %s to %s rewrite to %s with port ranges: %v", target.SourcePrefix, target.DestPrefix, target.RewriteTo, target.PortRange) } } diff --git a/healthcheck/healthcheck.go b/healthcheck/healthcheck.go index 23ca4bd..8de3008 100644 --- a/healthcheck/healthcheck.go +++ b/healthcheck/healthcheck.go @@ -58,7 +58,7 @@ type Target struct { LastCheck time.Time `json:"lastCheck"` LastError string `json:"lastError,omitempty"` CheckCount int `json:"checkCount"` - ticker *time.Ticker + timer *time.Timer ctx context.Context cancel context.CancelFunc } @@ -304,26 +304,26 @@ func (m *Monitor) monitorTarget(target *Target) { go m.callback(m.GetTargets()) } - // Set up ticker based on current status + // Set up timer based on current status interval := time.Duration(target.Config.Interval) * time.Second if target.Status == StatusUnhealthy { interval = time.Duration(target.Config.UnhealthyInterval) * time.Second } logger.Debug("Target %d: initial check interval set to %v", target.Config.ID, interval) - target.ticker = time.NewTicker(interval) - defer target.ticker.Stop() + target.timer = time.NewTimer(interval) + defer target.timer.Stop() for { select { case <-target.ctx.Done(): logger.Info("Stopping health check monitoring for target %d", target.Config.ID) return - case <-target.ticker.C: + case <-target.timer.C: oldStatus := target.Status m.performHealthCheck(target) - // Update ticker interval if status changed + // Update timer interval if status changed newInterval := time.Duration(target.Config.Interval) * time.Second if target.Status == StatusUnhealthy { newInterval = time.Duration(target.Config.UnhealthyInterval) * time.Second @@ -332,11 +332,12 @@ func (m *Monitor) monitorTarget(target *Target) { if newInterval != interval { logger.Debug("Target %d: updating check interval from %v to %v due to status change", target.Config.ID, interval, newInterval) - target.ticker.Stop() - target.ticker = time.NewTicker(newInterval) interval = newInterval } + // Reset timer for next check with current interval + target.timer.Reset(interval) + // Notify callback if status changed if oldStatus != target.Status && m.callback != nil { logger.Info("Target %d status changed: %s -> %s", diff --git a/holepunch/holepunch.go b/holepunch/holepunch.go index 8ee8767..e6a6cf3 100644 --- a/holepunch/holepunch.go +++ b/holepunch/holepunch.go @@ -22,6 +22,7 @@ type ExitNode struct { Endpoint string `json:"endpoint"` RelayPort uint16 `json:"relayPort"` PublicKey string `json:"publicKey"` + SiteIds []int `json:"siteIds,omitempty"` } // Manager handles UDP hole punching operations @@ -142,6 +143,51 @@ func (m *Manager) RemoveExitNode(endpoint string) bool { return true } +/* +RemoveExitNodesByPeer removes the peer ID from the SiteIds list in each exit node. +If the SiteIds list becomes empty after removal, the exit node is removed entirely. +Returns the number of exit nodes removed. +*/ +func (m *Manager) RemoveExitNodesByPeer(peerID int) int { + m.mu.Lock() + defer m.mu.Unlock() + + removed := 0 + for endpoint, node := range m.exitNodes { + // Remove peerID from SiteIds if present + newSiteIds := make([]int, 0, len(node.SiteIds)) + for _, id := range node.SiteIds { + if id != peerID { + newSiteIds = append(newSiteIds, id) + } + } + if len(newSiteIds) != len(node.SiteIds) { + node.SiteIds = newSiteIds + if len(node.SiteIds) == 0 { + delete(m.exitNodes, endpoint) + logger.Info("Removed exit node %s as no more site IDs remain after removing peer %d", endpoint, peerID) + removed++ + } else { + m.exitNodes[endpoint] = node + logger.Info("Removed peer %d from exit node %s site IDs", peerID, endpoint) + } + } + } + + if removed > 0 { + // Signal the goroutine to refresh if running + if m.running && m.updateChan != nil { + select { + case m.updateChan <- struct{}{}: + default: + // Channel full or closed, skip + } + } + } + + return removed +} + // GetExitNodes returns a copy of the current exit nodes func (m *Manager) GetExitNodes() []ExitNode { m.mu.Lock() diff --git a/main.go b/main.go index 9fdec74..c41ea35 100644 --- a/main.go +++ b/main.go @@ -389,6 +389,13 @@ func runNewtMain(ctx context.Context) { tlsClientCAs = append(tlsClientCAs, tlsClientCAsFlag...) } + if *version { + fmt.Println("Newt version " + newtVersion) + os.Exit(0) + } else { + logger.Info("Newt version %s", newtVersion) + } + logger.Init(nil) loggerLevel := util.ParseLogLevel(logLevel) logger.GetLogger().SetLevel(loggerLevel) @@ -440,13 +447,6 @@ func runNewtMain(ctx context.Context) { defer func() { _ = tel.Shutdown(context.Background()) }() } - if *version { - fmt.Println("Newt version " + newtVersion) - os.Exit(0) - } else { - logger.Info("Newt version %s", newtVersion) - } - if err := updates.CheckForUpdate("fosrl", "newt", newtVersion); err != nil { logger.Error("Error checking for updates: %v\n", err) } diff --git a/netstack2/handlers.go b/netstack2/handlers.go index bdc9feb..014d872 100644 --- a/netstack2/handlers.go +++ b/netstack2/handlers.go @@ -10,12 +10,18 @@ import ( "fmt" "io" "net" + "net/netip" + "os/exec" "sync" "time" "github.com/fosrl/newt/logger" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/checksum" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" @@ -58,6 +64,9 @@ const ( // Buffer size for copying data bufferSize = 32 * 1024 + + // icmpTimeout is the default timeout for ICMP ping requests. + icmpTimeout = 5 * time.Second ) // TCPHandler handles TCP connections from netstack @@ -72,6 +81,12 @@ type UDPHandler struct { proxyHandler *ProxyHandler } +// ICMPHandler handles ICMP packets from netstack +type ICMPHandler struct { + stack *stack.Stack + proxyHandler *ProxyHandler +} + // NewTCPHandler creates a new TCP handler func NewTCPHandler(s *stack.Stack, ph *ProxyHandler) *TCPHandler { return &TCPHandler{stack: s, proxyHandler: ph} @@ -82,6 +97,11 @@ func NewUDPHandler(s *stack.Stack, ph *ProxyHandler) *UDPHandler { return &UDPHandler{stack: s, proxyHandler: ph} } +// NewICMPHandler creates a new ICMP handler +func NewICMPHandler(s *stack.Stack, ph *ProxyHandler) *ICMPHandler { + return &ICMPHandler{stack: s, proxyHandler: ph} +} + // InstallTCPHandler installs the TCP forwarder on the stack func (h *TCPHandler) InstallTCPHandler() error { tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { @@ -348,3 +368,334 @@ func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) dst.SetReadDeadline(time.Now().Add(timeout)) } } + +// InstallICMPHandler installs the ICMP handler on the stack +func (h *ICMPHandler) InstallICMPHandler() error { + h.stack.SetTransportProtocolHandler(header.ICMPv4ProtocolNumber, h.handleICMPPacket) + logger.Info("ICMP Handler: Installed ICMP protocol handler") + return nil +} + +// handleICMPPacket handles incoming ICMP packets +func (h *ICMPHandler) handleICMPPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + logger.Debug("ICMP Handler: Received ICMP packet from %s to %s", id.RemoteAddress, id.LocalAddress) + + // Get the ICMP header from the packet + icmpData := pkt.TransportHeader().Slice() + if len(icmpData) < header.ICMPv4MinimumSize { + logger.Debug("ICMP Handler: Packet too small for ICMP header: %d bytes", len(icmpData)) + return false + } + + icmpHdr := header.ICMPv4(icmpData) + icmpType := icmpHdr.Type() + icmpCode := icmpHdr.Code() + + logger.Debug("ICMP Handler: Type=%d, Code=%d, Ident=%d, Seq=%d", + icmpType, icmpCode, icmpHdr.Ident(), icmpHdr.Sequence()) + + // Only handle Echo Request (ping) + if icmpType != header.ICMPv4Echo { + logger.Debug("ICMP Handler: Ignoring non-echo ICMP type: %d", icmpType) + return false + } + + // Extract source and destination addresses + srcIP := id.RemoteAddress.String() + dstIP := id.LocalAddress.String() + + logger.Info("ICMP Handler: Echo Request from %s to %s (ident=%d, seq=%d)", + srcIP, dstIP, icmpHdr.Ident(), icmpHdr.Sequence()) + + // Convert to netip.Addr for subnet matching + srcAddr, err := netip.ParseAddr(srcIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse source IP %s: %v", srcIP, err) + return false + } + dstAddr, err := netip.ParseAddr(dstIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse dest IP %s: %v", dstIP, err) + return false + } + + // Check subnet rules (use port 0 for ICMP since it doesn't have ports) + if h.proxyHandler == nil { + logger.Debug("ICMP Handler: No proxy handler configured") + return false + } + + matchedRule := h.proxyHandler.subnetLookup.Match(srcAddr, dstAddr, 0, header.ICMPv4ProtocolNumber) + if matchedRule == nil { + logger.Debug("ICMP Handler: No matching subnet rule for %s -> %s", srcIP, dstIP) + return false + } + + logger.Info("ICMP Handler: Matched subnet rule for %s -> %s", srcIP, dstIP) + + // Determine actual destination (with possible rewrite) + actualDstIP := dstIP + if matchedRule.RewriteTo != "" { + resolvedAddr, err := h.proxyHandler.resolveRewriteAddress(matchedRule.RewriteTo) + if err != nil { + logger.Info("ICMP Handler: Failed to resolve rewrite address %s: %v", matchedRule.RewriteTo, err) + } else { + actualDstIP = resolvedAddr.String() + logger.Info("ICMP Handler: Using rewritten destination %s (original: %s)", actualDstIP, dstIP) + } + } + + // Get the full ICMP payload (including the data after the header) + icmpPayload := pkt.Data().AsRange().ToSlice() + + // Handle the ping in a goroutine to avoid blocking + go h.proxyPing(srcIP, dstIP, actualDstIP, icmpHdr.Ident(), icmpHdr.Sequence(), icmpPayload) + + return true +} + +// proxyPing sends a ping to the actual destination and injects the reply back +func (h *ICMPHandler) proxyPing(srcIP, originalDstIP, actualDstIP string, ident, seq uint16, payload []byte) { + logger.Debug("ICMP Handler: Proxying ping from %s to %s (actual: %s), ident=%d, seq=%d", + srcIP, originalDstIP, actualDstIP, ident, seq) + + // Try three methods in order: ip4:icmp -> udp4 -> ping command + // Track which method succeeded so we can handle identifier matching correctly + method, success := h.tryICMPMethods(actualDstIP, ident, seq, payload) + + if !success { + logger.Info("ICMP Handler: All ping methods failed for %s", actualDstIP) + return + } + + logger.Info("ICMP Handler: Ping successful to %s using %s, injecting reply (ident=%d, seq=%d)", + actualDstIP, method, ident, seq) + + // Build the reply packet to inject back into the netstack + // The reply should appear to come from the original destination (before rewrite) + h.injectICMPReply(srcIP, originalDstIP, ident, seq, payload) +} + +// tryICMPMethods tries all available ICMP methods in order +func (h *ICMPHandler) tryICMPMethods(actualDstIP string, ident, seq uint16, payload []byte) (string, bool) { + if h.tryRawICMP(actualDstIP, ident, seq, payload, false) { + return "raw ICMP", true + } + if h.tryUnprivilegedICMP(actualDstIP, ident, seq, payload) { + return "unprivileged ICMP", true + } + if h.tryPingCommand(actualDstIP, ident, seq, payload) { + return "ping command", true + } + return "", false +} + +// tryRawICMP attempts to ping using raw ICMP sockets (requires CAP_NET_RAW or root) +func (h *ICMPHandler) tryRawICMP(actualDstIP string, ident, seq uint16, payload []byte, ignoreIdent bool) bool { + conn, err := icmp.ListenPacket("ip4:icmp", "0.0.0.0") + if err != nil { + logger.Debug("ICMP Handler: Raw ICMP socket not available: %v", err) + return false + } + defer conn.Close() + + logger.Debug("ICMP Handler: Using raw ICMP socket") + return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, false, ignoreIdent) +} + +// tryUnprivilegedICMP attempts to ping using unprivileged ICMP (requires ping_group_range configured) +func (h *ICMPHandler) tryUnprivilegedICMP(actualDstIP string, ident, seq uint16, payload []byte) bool { + conn, err := icmp.ListenPacket("udp4", "0.0.0.0") + if err != nil { + logger.Debug("ICMP Handler: Unprivileged ICMP socket not available: %v", err) + return false + } + defer conn.Close() + + logger.Debug("ICMP Handler: Using unprivileged ICMP socket") + // Unprivileged ICMP doesn't let us control the identifier, so we ignore it in matching + return h.sendAndReceiveICMP(conn, actualDstIP, ident, seq, payload, true, true) +} + +// sendAndReceiveICMP sends an ICMP echo request and waits for the reply +func (h *ICMPHandler) sendAndReceiveICMP(conn *icmp.PacketConn, actualDstIP string, ident, seq uint16, payload []byte, isUnprivileged bool, ignoreIdent bool) bool { + // Build the ICMP echo request message + echoMsg := &icmp.Message{ + Type: ipv4.ICMPTypeEcho, + Code: 0, + Body: &icmp.Echo{ + ID: int(ident), + Seq: int(seq), + Data: payload, + }, + } + + msgBytes, err := echoMsg.Marshal(nil) + if err != nil { + logger.Debug("ICMP Handler: Failed to marshal ICMP message: %v", err) + return false + } + + // Resolve destination address based on socket type + var writeErr error + if isUnprivileged { + // For unprivileged ICMP, use UDP-style addressing + udpAddr := &net.UDPAddr{IP: net.ParseIP(actualDstIP)} + logger.Debug("ICMP Handler: Sending ping to %s (unprivileged)", udpAddr.String()) + conn.SetDeadline(time.Now().Add(icmpTimeout)) + _, writeErr = conn.WriteTo(msgBytes, udpAddr) + } else { + // For raw ICMP, use IP addressing + dst, err := net.ResolveIPAddr("ip4", actualDstIP) + if err != nil { + logger.Debug("ICMP Handler: Failed to resolve destination %s: %v", actualDstIP, err) + return false + } + logger.Debug("ICMP Handler: Sending ping to %s (raw)", dst.String()) + conn.SetDeadline(time.Now().Add(icmpTimeout)) + _, writeErr = conn.WriteTo(msgBytes, dst) + } + + if writeErr != nil { + logger.Debug("ICMP Handler: Failed to send ping to %s: %v", actualDstIP, writeErr) + return false + } + + logger.Debug("ICMP Handler: Ping sent to %s, waiting for reply (ident=%d, seq=%d)", actualDstIP, ident, seq) + + // Wait for reply - loop to filter out non-matching packets + replyBuf := make([]byte, 1500) + + for { + n, peer, err := conn.ReadFrom(replyBuf) + if err != nil { + logger.Debug("ICMP Handler: Failed to receive ping reply from %s: %v", actualDstIP, err) + return false + } + + logger.Debug("ICMP Handler: Received %d bytes from %s", n, peer.String()) + + // Parse the reply + replyMsg, err := icmp.ParseMessage(1, replyBuf[:n]) + if err != nil { + logger.Debug("ICMP Handler: Failed to parse ICMP message: %v", err) + continue + } + + // Check if it's an echo reply (type 0), not an echo request (type 8) + if replyMsg.Type != ipv4.ICMPTypeEchoReply { + logger.Debug("ICMP Handler: Received non-echo-reply type: %v, continuing to wait", replyMsg.Type) + continue + } + + reply, ok := replyMsg.Body.(*icmp.Echo) + if !ok { + logger.Debug("ICMP Handler: Invalid echo reply body type, continuing to wait") + continue + } + + // Verify the sequence matches what we sent + // For unprivileged ICMP, the kernel controls the identifier, so we only check sequence + if reply.Seq != int(seq) { + logger.Debug("ICMP Handler: Reply seq mismatch: got seq=%d, want seq=%d", reply.Seq, seq) + continue + } + + if !ignoreIdent && reply.ID != int(ident) { + logger.Debug("ICMP Handler: Reply ident mismatch: got ident=%d, want ident=%d", reply.ID, ident) + continue + } + + // Found matching reply + logger.Debug("ICMP Handler: Received valid echo reply") + return true + } +} + +// tryPingCommand attempts to ping using the system ping command (always works, but less control) +func (h *ICMPHandler) tryPingCommand(actualDstIP string, ident, seq uint16, payload []byte) bool { + logger.Debug("ICMP Handler: Attempting to use system ping command") + + ctx, cancel := context.WithTimeout(context.Background(), icmpTimeout) + defer cancel() + + // Send one ping with timeout + // -c 1: count = 1 packet + // -W 5: timeout = 5 seconds + // -q: quiet output (just summary) + cmd := exec.CommandContext(ctx, "ping", "-c", "1", "-W", "5", "-q", actualDstIP) + output, err := cmd.CombinedOutput() + + if err != nil { + logger.Debug("ICMP Handler: System ping command failed: %v, output: %s", err, string(output)) + return false + } + + logger.Debug("ICMP Handler: System ping command succeeded") + return true +} + +// injectICMPReply creates an ICMP echo reply packet and queues it to be sent back through the tunnel +func (h *ICMPHandler) injectICMPReply(dstIP, srcIP string, ident, seq uint16, payload []byte) { + logger.Debug("ICMP Handler: Creating reply from %s to %s (ident=%d, seq=%d)", + srcIP, dstIP, ident, seq) + + // Parse addresses + srcAddr, err := netip.ParseAddr(srcIP) + if err != nil { + logger.Info("ICMP Handler: Failed to parse source IP for reply: %v", err) + return + } + dstAddr, err := netip.ParseAddr(dstIP) + if err != nil { + logger.Info("ICMP Handler: Failed to parse dest IP for reply: %v", err) + return + } + + // Calculate total packet size + ipHeaderLen := header.IPv4MinimumSize + icmpHeaderLen := header.ICMPv4MinimumSize + totalLen := ipHeaderLen + icmpHeaderLen + len(payload) + + // Create the packet buffer + pkt := make([]byte, totalLen) + + // Build IPv4 header + ipHdr := header.IPv4(pkt[:ipHeaderLen]) + ipHdr.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLen), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: tcpip.AddrFrom4(srcAddr.As4()), + DstAddr: tcpip.AddrFrom4(dstAddr.As4()), + }) + ipHdr.SetChecksum(^ipHdr.CalculateChecksum()) + + // Build ICMP header + icmpHdr := header.ICMPv4(pkt[ipHeaderLen : ipHeaderLen+icmpHeaderLen]) + icmpHdr.SetType(header.ICMPv4EchoReply) + icmpHdr.SetCode(0) + icmpHdr.SetIdent(ident) + icmpHdr.SetSequence(seq) + + // Copy payload + copy(pkt[ipHeaderLen+icmpHeaderLen:], payload) + + // Calculate ICMP checksum (covers ICMP header + payload) + icmpHdr.SetChecksum(0) + icmpData := pkt[ipHeaderLen:] + icmpHdr.SetChecksum(^checksum.Checksum(icmpData, 0)) + + logger.Debug("ICMP Handler: Built reply packet, total length=%d", totalLen) + + // Queue the packet to be sent back through the tunnel + if h.proxyHandler != nil { + if h.proxyHandler.QueueICMPReply(pkt) { + logger.Info("ICMP Handler: Queued echo reply packet for transmission") + } else { + logger.Info("ICMP Handler: Failed to queue echo reply packet") + } + } else { + logger.Info("ICMP Handler: Cannot queue reply - proxy handler not available") + } +} diff --git a/netstack2/proxy.go b/netstack2/proxy.go index 77a9d23..fefb18d 100644 --- a/netstack2/proxy.go +++ b/netstack2/proxy.go @@ -22,10 +22,12 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -// PortRange represents an allowed range of ports (inclusive) +// PortRange represents an allowed range of ports (inclusive) with optional protocol filtering +// Protocol can be "tcp", "udp", or "" (empty string means both protocols) type PortRange struct { - Min uint16 - Max uint16 + Min uint16 + Max uint16 + Protocol string // "tcp", "udp", or "" for both } // SubnetRule represents a subnet with optional port restrictions and source address @@ -41,6 +43,7 @@ type PortRange struct { type SubnetRule struct { SourcePrefix netip.Prefix // Source IP prefix (who is sending) DestPrefix netip.Prefix // Destination IP prefix (where it's going) + DisableIcmp bool // If true, ICMP traffic is blocked for this subnet RewriteTo string // Optional rewrite address for DNAT - can be IP/CIDR or domain name PortRanges []PortRange // empty slice means all ports allowed } @@ -67,7 +70,7 @@ func NewSubnetLookup() *SubnetLookup { // AddSubnet adds a subnet rule with source and destination prefixes and optional port restrictions // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { sl.mu.Lock() defer sl.mu.Unlock() @@ -79,6 +82,7 @@ func (sl *SubnetLookup) AddSubnet(sourcePrefix, destPrefix netip.Prefix, rewrite sl.rules[key] = &SubnetRule{ SourcePrefix: sourcePrefix, DestPrefix: destPrefix, + DisableIcmp: disableIcmp, RewriteTo: rewriteTo, PortRanges: portRanges, } @@ -97,14 +101,16 @@ func (sl *SubnetLookup) RemoveSubnet(sourcePrefix, destPrefix netip.Prefix) { delete(sl.rules, key) } -// Match checks if a source IP, destination IP, and port match any subnet rule -// Returns the matched rule if BOTH: +// Match checks if a source IP, destination IP, port, and protocol match any subnet rule +// Returns the matched rule if ALL of these conditions are met: // - The source IP is in the rule's source prefix // - The destination IP is in the rule's destination prefix // - The port is in an allowed range (or no port restrictions exist) +// - The protocol matches (or the port range allows both protocols) // +// proto should be header.TCPProtocolNumber or header.UDPProtocolNumber // Returns nil if no rule matches -func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule { +func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16, proto tcpip.TransportProtocolNumber) *SubnetRule { sl.mu.RLock() defer sl.mu.RUnlock() @@ -119,16 +125,31 @@ func (sl *SubnetLookup) Match(srcIP, dstIP netip.Addr, port uint16) *SubnetRule continue } + if rule.DisableIcmp && (proto == header.ICMPv4ProtocolNumber || proto == header.ICMPv6ProtocolNumber) { + // ICMP is disabled for this subnet + return nil + } + // Both IPs match - now check port restrictions // If no port ranges specified, all ports are allowed if len(rule.PortRanges) == 0 { return rule } - // Check if port is in any of the allowed ranges + // Check if port and protocol are in any of the allowed ranges for _, pr := range rule.PortRanges { if port >= pr.Min && port <= pr.Max { - return rule + // Check protocol compatibility + if pr.Protocol == "" { + // Empty protocol means allow both TCP and UDP + return rule + } + // Check if the packet protocol matches the port range protocol + if (pr.Protocol == "tcp" && proto == header.TCPProtocolNumber) || + (pr.Protocol == "udp" && proto == header.UDPProtocolNumber) { + return rule + } + // Port matches but protocol doesn't - continue checking other ranges } } } @@ -166,23 +187,27 @@ type ProxyHandler struct { proxyNotifyHandle *channel.NotificationHandle tcpHandler *TCPHandler udpHandler *UDPHandler + icmpHandler *ICMPHandler subnetLookup *SubnetLookup natTable map[connKey]*natState destRewriteTable map[destKey]netip.Addr // Maps original dest to rewritten dest for handler lookups natMu sync.RWMutex enabled bool + icmpReplies chan []byte // Channel for ICMP reply packets to be sent back through the tunnel + notifiable channel.Notification // Notification handler for triggering reads } // ProxyHandlerOptions configures the proxy handler type ProxyHandlerOptions struct { - EnableTCP bool - EnableUDP bool - MTU int + EnableTCP bool + EnableUDP bool + EnableICMP bool + MTU int } // NewProxyHandler creates a new proxy handler for promiscuous mode func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { - if !options.EnableTCP && !options.EnableUDP { + if !options.EnableTCP && !options.EnableUDP && !options.EnableICMP { return nil, nil // No proxy needed } @@ -191,6 +216,7 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { subnetLookup: NewSubnetLookup(), natTable: make(map[connKey]*natState), destRewriteTable: make(map[destKey]netip.Addr), + icmpReplies: make(chan []byte, 256), // Buffer for ICMP reply packets proxyEp: channel.New(1024, uint32(options.MTU), ""), proxyStack: stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ @@ -222,6 +248,15 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { } } + // Initialize ICMP handler if enabled + if options.EnableICMP { + handler.icmpHandler = NewICMPHandler(handler.proxyStack, handler) + if err := handler.icmpHandler.InstallICMPHandler(); err != nil { + return nil, fmt.Errorf("failed to install ICMP handler: %v", err) + } + logger.Info("ProxyHandler: ICMP handler enabled") + } + // // Example 1: Add a rule with no port restrictions (all ports allowed) // // This accepts all traffic FROM 10.0.0.0/24 TO 10.20.20.0/24 // sourceSubnet := netip.MustParsePrefix("10.0.0.0/24") @@ -246,11 +281,11 @@ func NewProxyHandler(options ProxyHandlerOptions) (*ProxyHandler, error) { // destPrefix: The IP prefix of the destination // rewriteTo: Optional address to rewrite destination to - can be IP/CIDR or domain name // If portRanges is nil or empty, all ports are allowed for this subnet -func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (p *ProxyHandler) AddSubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { if p == nil || !p.enabled { return } - p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges) + p.subnetLookup.AddSubnet(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) } // RemoveSubnetRule removes a subnet from the proxy handler @@ -329,6 +364,9 @@ func (p *ProxyHandler) Initialize(notifiable channel.Notification) error { return nil } + // Store notifiable for triggering notifications on ICMP replies + p.notifiable = notifiable + // Add notification handler p.proxyNotifyHandle = p.proxyEp.AddNotify(notifiable) @@ -407,14 +445,21 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { } udpHeader := header.UDP(packet[headerLen:]) dstPort = udpHeader.DestinationPort() - default: - // For other protocols (ICMP, etc.), use port 0 (must match rules with no port restrictions) + case header.ICMPv4ProtocolNumber: + // ICMP doesn't have ports, use port 0 (must match rules with no port restrictions) dstPort = 0 + logger.Debug("HandleIncomingPacket: ICMP packet from %s to %s", srcAddr, dstAddr) + default: + // For other protocols, use port 0 (must match rules with no port restrictions) + dstPort = 0 + logger.Debug("HandleIncomingPacket: Unknown protocol %d from %s to %s", protocol, srcAddr, dstAddr) } - // Check if the source IP, destination IP, and port match any subnet rule - matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort) + // Check if the source IP, destination IP, port, and protocol match any subnet rule + matchedRule := p.subnetLookup.Match(srcAddr, dstAddr, dstPort, protocol) if matchedRule != nil { + logger.Debug("HandleIncomingPacket: Matched rule for %s -> %s (proto=%d, port=%d)", + srcAddr, dstAddr, protocol, dstPort) // Check if we need to perform DNAT if matchedRule.RewriteTo != "" { // Create connection tracking key using original destination @@ -501,9 +546,12 @@ func (p *ProxyHandler) HandleIncomingPacket(packet []byte) bool { Payload: buffer.MakeWithData(packet), }) p.proxyEp.InjectInbound(header.IPv4ProtocolNumber, pkb) + logger.Debug("HandleIncomingPacket: Injected packet into proxy stack (proto=%d)", protocol) return true } + logger.Debug("HandleIncomingPacket: No matching rule for %s -> %s (proto=%d, port=%d)", + srcAddr, dstAddr, protocol, dstPort) return false } @@ -626,6 +674,15 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { return nil } + // First check for ICMP reply packets (non-blocking) + select { + case icmpReply := <-p.icmpReplies: + logger.Debug("ReadOutgoingPacket: Returning ICMP reply packet (%d bytes)", len(icmpReply)) + return buffer.NewViewWithData(icmpReply) + default: + // No ICMP reply available, continue to check proxy endpoint + } + pkt := p.proxyEp.Read() if pkt != nil { view := pkt.ToView() @@ -655,6 +712,11 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { srcPort = udpHeader.SourcePort() dstPort = udpHeader.DestinationPort() } + case header.ICMPv4ProtocolNumber: + // ICMP packets don't need NAT translation in our implementation + // since we construct reply packets with the correct addresses + logger.Debug("ReadOutgoingPacket: ICMP packet from %s to %s", srcIP, dstIP) + return view } // Look up NAT state for reverse translation @@ -688,12 +750,37 @@ func (p *ProxyHandler) ReadOutgoingPacket() *buffer.View { return nil } +// QueueICMPReply queues an ICMP reply packet to be sent back through the tunnel +func (p *ProxyHandler) QueueICMPReply(packet []byte) bool { + if p == nil || !p.enabled { + return false + } + + select { + case p.icmpReplies <- packet: + logger.Debug("QueueICMPReply: Queued ICMP reply packet (%d bytes)", len(packet)) + // Trigger notification so WriteNotify picks up the packet + if p.notifiable != nil { + p.notifiable.WriteNotify() + } + return true + default: + logger.Info("QueueICMPReply: ICMP reply channel full, dropping packet") + return false + } +} + // Close cleans up the proxy handler resources func (p *ProxyHandler) Close() error { if p == nil || !p.enabled { return nil } + // Close ICMP replies channel + if p.icmpReplies != nil { + close(p.icmpReplies) + } + if p.proxyStack != nil { p.proxyStack.RemoveNIC(1) p.proxyStack.Close() diff --git a/netstack2/tun.go b/netstack2/tun.go index 4bcea65..e743f1e 100644 --- a/netstack2/tun.go +++ b/netstack2/tun.go @@ -56,15 +56,17 @@ type Net netTun // NetTunOptions contains options for creating a NetTUN device type NetTunOptions struct { - EnableTCPProxy bool - EnableUDPProxy bool + EnableTCPProxy bool + EnableUDPProxy bool + EnableICMPProxy bool } // CreateNetTUN creates a new TUN device with netstack without proxying func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) { return CreateNetTUNWithOptions(localAddresses, dnsServers, mtu, NetTunOptions{ - EnableTCPProxy: true, - EnableUDPProxy: true, + EnableTCPProxy: true, + EnableUDPProxy: true, + EnableICMPProxy: true, }) } @@ -84,13 +86,14 @@ func CreateNetTUNWithOptions(localAddresses, dnsServers []netip.Addr, mtu int, o mtu: mtu, } - // Initialize proxy handler if TCP or UDP proxying is enabled - if options.EnableTCPProxy || options.EnableUDPProxy { + // Initialize proxy handler if TCP, UDP, or ICMP proxying is enabled + if options.EnableTCPProxy || options.EnableUDPProxy || options.EnableICMPProxy { var err error dev.proxyHandler, err = NewProxyHandler(ProxyHandlerOptions{ - EnableTCP: options.EnableTCPProxy, - EnableUDP: options.EnableUDPProxy, - MTU: mtu, + EnableTCP: options.EnableTCPProxy, + EnableUDP: options.EnableUDPProxy, + EnableICMP: options.EnableICMPProxy, + MTU: mtu, }) if err != nil { return nil, nil, fmt.Errorf("failed to create proxy handler: %v", err) @@ -351,10 +354,10 @@ func (net *Net) ListenUDP(laddr *net.UDPAddr) (*gonet.UDPConn, error) { // AddProxySubnetRule adds a subnet rule to the proxy handler // If portRanges is nil or empty, all ports are allowed for this subnet // rewriteTo can be either an IP/CIDR (e.g., "192.168.1.1/32") or a domain name (e.g., "example.com") -func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange) { +func (net *Net) AddProxySubnetRule(sourcePrefix, destPrefix netip.Prefix, rewriteTo string, portRanges []PortRange, disableIcmp bool) { tun := (*netTun)(net) if tun.proxyHandler != nil { - tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges) + tun.proxyHandler.AddSubnetRule(sourcePrefix, destPrefix, rewriteTo, portRanges, disableIcmp) } } diff --git a/udp_client.py b/udp_client.py new file mode 100644 index 0000000..2909d13 --- /dev/null +++ b/udp_client.py @@ -0,0 +1,49 @@ +import socket +import sys + +# Argument parsing: Check if IP and Port are provided +if len(sys.argv) != 3: + print("Usage: python udp_client.py ") + # Example: python udp_client.py 127.0.0.1 12000 + sys.exit(1) + +HOST = sys.argv[1] +try: + PORT = int(sys.argv[2]) +except ValueError: + print("Error: HOST_PORT must be an integer.") + sys.exit(1) + +# The message to send to the server +MESSAGE = "Hello UDP Server! How are you?" + +# Create a UDP socket +try: + client_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +except socket.error as err: + print(f"Failed to create socket: {err}") + sys.exit() + +try: + print(f"Sending message to {HOST}:{PORT}...") + + # Send the message (data must be encoded to bytes) + client_socket.sendto(MESSAGE.encode('utf-8'), (HOST, PORT)) + + # Wait for the server's response (buffer size 1024 bytes) + data, server_address = client_socket.recvfrom(1024) + + # Decode and print the server's response + response = data.decode('utf-8') + print("-" * 30) + print(f"Received response from server {server_address[0]}:{server_address[1]}:") + print(f"-> Data: '{response}'") + +except socket.error as err: + print(f"Error during communication: {err}") + +finally: + # Close the socket + client_socket.close() + print("-" * 30) + print("Client finished and socket closed.") diff --git a/udp_server.py b/udp_server.py new file mode 100644 index 0000000..89aea28 --- /dev/null +++ b/udp_server.py @@ -0,0 +1,58 @@ +import socket +import sys + +# optionally take in some positional args for the port +if len(sys.argv) > 1: + try: + PORT = int(sys.argv[1]) + except ValueError: + print("Invalid port number. Using default port 12000.") + PORT = 12000 +else: + PORT = 12000 + +# Define the server host and port +HOST = '0.0.0.0' # Standard loopback interface address (localhost) + +# Create a UDP socket +try: + server_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) +except socket.error as err: + print(f"Failed to create socket: {err}") + sys.exit() + +# Bind the socket to the address +try: + server_socket.bind((HOST, PORT)) + print(f"UDP Server listening on {HOST}:{PORT}") +except socket.error as err: + print(f"Bind failed: {err}") + server_socket.close() + sys.exit() + +# Wait for and process incoming data +while True: + try: + # Receive data and the client's address (buffer size 1024 bytes) + data, client_address = server_socket.recvfrom(1024) + + # Decode the data and print the message + message = data.decode('utf-8') + print("-" * 30) + print(f"Received message from {client_address[0]}:{client_address[1]}:") + print(f"-> Data: '{message}'") + + # Prepare the response message + response_message = f"Hello client! Server received: '{message.upper()}'" + + # Send the response back to the client + server_socket.sendto(response_message.encode('utf-8'), client_address) + print(f"Sent response back to client.") + + except Exception as e: + print(f"An error occurred: {e}") + break + +# Clean up (though usually unreachable in an infinite server loop) +server_socket.close() +print("Server stopped.")