/* SPDX-License-Identifier: MIT * * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved. */ package netstack2 import ( "context" "fmt" "io" "net" "sync" "time" "github.com/fosrl/newt/logger" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) const ( // defaultWndSize if set to zero, the default // receive window buffer size is used instead. defaultWndSize = 0 // maxConnAttempts specifies the maximum number // of in-flight tcp connection attempts. maxConnAttempts = 2 << 10 // tcpKeepaliveCount is the maximum number of // TCP keep-alive probes to send before giving up // and killing the connection if no response is // obtained from the other end. tcpKeepaliveCount = 9 // tcpKeepaliveIdle specifies the time a connection // must remain idle before the first TCP keepalive // packet is sent. Once this time is reached, // tcpKeepaliveInterval option is used instead. tcpKeepaliveIdle = 60 * time.Second // tcpKeepaliveInterval specifies the interval // time between sending TCP keepalive packets. tcpKeepaliveInterval = 30 * time.Second // tcpConnectTimeout is the default timeout for TCP handshakes. tcpConnectTimeout = 5 * time.Second // tcpWaitTimeout implements a TCP half-close timeout. tcpWaitTimeout = 60 * time.Second // udpSessionTimeout is the default timeout for UDP sessions. udpSessionTimeout = 60 * time.Second // Buffer size for copying data bufferSize = 32 * 1024 ) // TCPHandler handles TCP connections from netstack type TCPHandler struct { stack *stack.Stack } // UDPHandler handles UDP connections from netstack type UDPHandler struct { stack *stack.Stack } // NewTCPHandler creates a new TCP handler func NewTCPHandler(s *stack.Stack) *TCPHandler { return &TCPHandler{stack: s} } // NewUDPHandler creates a new UDP handler func NewUDPHandler(s *stack.Stack) *UDPHandler { return &UDPHandler{stack: s} } // InstallTCPHandler installs the TCP forwarder on the stack func (h *TCPHandler) InstallTCPHandler() error { tcpForwarder := tcp.NewForwarder(h.stack, defaultWndSize, maxConnAttempts, func(r *tcp.ForwarderRequest) { var ( wq waiter.Queue ep tcpip.Endpoint err tcpip.Error id = r.ID() ) // Perform a TCP three-way handshake ep, err = r.CreateEndpoint(&wq) if err != nil { // RST: prevent potential half-open TCP connection leak r.Complete(true) return } defer r.Complete(false) // Set socket options setTCPSocketOptions(h.stack, ep) // Create TCP connection from netstack endpoint netstackConn := gonet.NewTCPConn(&wq, ep) // Handle the connection in a goroutine go h.handleTCPConn(netstackConn, id) }) h.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) return nil } // handleTCPConn handles a TCP connection by proxying it to the actual target func (h *TCPHandler) handleTCPConn(netstackConn *gonet.TCPConn, id stack.TransportEndpointID) { defer netstackConn.Close() // Extract source and target address from the connection ID srcIP := id.RemoteAddress.String() srcPort := id.RemotePort dstIP := id.LocalAddress.String() dstPort := id.LocalPort logger.Info("TCP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) targetAddr := fmt.Sprintf("%s:%d", dstIP, dstPort) // Create context with timeout for connection establishment ctx, cancel := context.WithTimeout(context.Background(), tcpConnectTimeout) defer cancel() // Dial the actual target using standard net package var d net.Dialer targetConn, err := d.DialContext(ctx, "tcp", targetAddr) if err != nil { logger.Info("TCP Forwarder: Failed to connect to %s: %v", targetAddr, err) // Connection failed, netstack will handle RST return } defer targetConn.Close() logger.Info("TCP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) // Bidirectional copy between netstack and target pipeTCP(netstackConn, targetConn) } // pipeTCP copies data bidirectionally between two connections func pipeTCP(origin, remote net.Conn) { wg := sync.WaitGroup{} wg.Add(2) go unidirectionalStreamTCP(remote, origin, "origin->remote", &wg) go unidirectionalStreamTCP(origin, remote, "remote->origin", &wg) wg.Wait() } // unidirectionalStreamTCP copies data in one direction func unidirectionalStreamTCP(dst, src net.Conn, dir string, wg *sync.WaitGroup) { defer wg.Done() buf := make([]byte, bufferSize) _, _ = io.CopyBuffer(dst, src, buf) // Do the upload/download side TCP half-close if cr, ok := src.(interface{ CloseRead() error }); ok { cr.CloseRead() } if cw, ok := dst.(interface{ CloseWrite() error }); ok { cw.CloseWrite() } // Set TCP half-close timeout dst.SetReadDeadline(time.Now().Add(tcpWaitTimeout)) } // setTCPSocketOptions sets TCP socket options for better performance func setTCPSocketOptions(s *stack.Stack, ep tcpip.Endpoint) { // TCP keepalive options ep.SocketOptions().SetKeepAlive(true) idle := tcpip.KeepaliveIdleOption(tcpKeepaliveIdle) ep.SetSockOpt(&idle) interval := tcpip.KeepaliveIntervalOption(tcpKeepaliveInterval) ep.SetSockOpt(&interval) ep.SetSockOptInt(tcpip.KeepaliveCountOption, tcpKeepaliveCount) // TCP send/recv buffer size var ss tcpip.TCPSendBufferSizeRangeOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &ss); err == nil { ep.SocketOptions().SetSendBufferSize(int64(ss.Default), false) } var rs tcpip.TCPReceiveBufferSizeRangeOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &rs); err == nil { ep.SocketOptions().SetReceiveBufferSize(int64(rs.Default), false) } } // InstallUDPHandler installs the UDP forwarder on the stack func (h *UDPHandler) InstallUDPHandler() error { udpForwarder := udp.NewForwarder(h.stack, func(r *udp.ForwarderRequest) { var ( wq waiter.Queue id = r.ID() ) ep, err := r.CreateEndpoint(&wq) if err != nil { return } // Create UDP connection from netstack endpoint netstackConn := gonet.NewUDPConn(&wq, ep) // Handle the connection in a goroutine go h.handleUDPConn(netstackConn, id) }) h.stack.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) return nil } // handleUDPConn handles a UDP connection by proxying it to the actual target func (h *UDPHandler) handleUDPConn(netstackConn *gonet.UDPConn, id stack.TransportEndpointID) { defer netstackConn.Close() // Extract source and target address from the connection ID srcIP := id.RemoteAddress.String() srcPort := id.RemotePort dstIP := id.LocalAddress.String() dstPort := id.LocalPort logger.Info("UDP Forwarder: Handling connection %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) targetAddr := fmt.Sprintf("%s:%d", dstIP, dstPort) // Resolve target address remoteUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { logger.Info("UDP Forwarder: Failed to resolve %s: %v", targetAddr, err) return } // Create UDP connection to target targetConn, err := net.DialUDP("udp", nil, remoteUDPAddr) if err != nil { logger.Info("UDP Forwarder: Failed to dial %s: %v", targetAddr, err) return } defer targetConn.Close() logger.Info("UDP Forwarder: Successfully connected to %s, starting bidirectional copy", targetAddr) // Bidirectional copy between netstack and target pipeUDP(netstackConn, targetConn, remoteUDPAddr, udpSessionTimeout) } // pipeUDP copies UDP packets bidirectionally func pipeUDP(origin, remote net.PacketConn, to net.Addr, timeout time.Duration) { wg := sync.WaitGroup{} wg.Add(2) go unidirectionalPacketStream(remote, origin, to, "origin->remote", &wg, timeout) go unidirectionalPacketStream(origin, remote, nil, "remote->origin", &wg, timeout) wg.Wait() } // unidirectionalPacketStream copies packets in one direction func unidirectionalPacketStream(dst, src net.PacketConn, to net.Addr, dir string, wg *sync.WaitGroup, timeout time.Duration) { defer wg.Done() _ = copyPacketData(dst, src, to, timeout) } // copyPacketData copies UDP packet data with timeout func copyPacketData(dst, src net.PacketConn, to net.Addr, timeout time.Duration) error { buf := make([]byte, 65535) // Max UDP packet size for { src.SetReadDeadline(time.Now().Add(timeout)) n, _, err := src.ReadFrom(buf) if ne, ok := err.(net.Error); ok && ne.Timeout() { return nil // ignore I/O timeout } else if err == io.EOF { return nil // ignore EOF } else if err != nil { return err } if _, err = dst.WriteTo(buf[:n], to); err != nil { return err } dst.SetReadDeadline(time.Now().Add(timeout)) } }