From 3a6365782281ac2e4af1ad537442d8f78d046ccf Mon Sep 17 00:00:00 2001 From: Owen Schwartz Date: Mon, 20 Jan 2025 21:11:06 -0500 Subject: [PATCH] Rewrite proxy manager --- proxy/manager.go | 502 ++++++++++++++++++++++++----------------------- proxy/types.go | 29 --- 2 files changed, 253 insertions(+), 278 deletions(-) delete mode 100644 proxy/types.go diff --git a/proxy/manager.go b/proxy/manager.go index ae89b3e..92218fa 100644 --- a/proxy/manager.go +++ b/proxy/manager.go @@ -4,328 +4,332 @@ import ( "fmt" "io" "net" - "strings" "sync" "time" "github.com/fosrl/newt/logger" - "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" - "golang.zx2c4.com/wireguard/tun/netstack" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" ) +// Target represents a proxy target with its address and port +type Target struct { + Address string + Port int +} + +// ProxyManager handles the creation and management of proxy connections +type ProxyManager struct { + tnet *netstack.Net + tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress + udpTargets map[string]map[int]string + listeners []*gonet.TCPListener + udpConns []*gonet.UDPConn + running bool + mutex sync.RWMutex +} + +// NewProxyManager creates a new proxy manager instance func NewProxyManager(tnet *netstack.Net) *ProxyManager { return &ProxyManager{ - tnet: tnet, + tnet: tnet, + tcpTargets: make(map[string]map[int]string), + udpTargets: make(map[string]map[int]string), + listeners: make([]*gonet.TCPListener, 0), + udpConns: make([]*gonet.UDPConn, 0), } } -func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) error { - pm.Lock() - defer pm.Unlock() +// AddTarget adds a new target for proxying +func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error { + pm.mutex.Lock() + defer pm.mutex.Unlock() - logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target) - newTarget := &ProxyTarget{ - Protocol: protocol, - Listen: listen, - Port: port, - Target: target, - cancel: make(chan struct{}), - done: make(chan struct{}), + switch proto { + case "tcp": + if pm.tcpTargets[listenIP] == nil { + pm.tcpTargets[listenIP] = make(map[int]string) + } + pm.tcpTargets[listenIP][port] = targetAddr + case "udp": + if pm.udpTargets[listenIP] == nil { + pm.udpTargets[listenIP] = make(map[int]string) + } + pm.udpTargets[listenIP][port] = targetAddr + default: + return fmt.Errorf("unsupported protocol: %s", proto) } - pm.targets = append(pm.targets, newTarget) + if pm.running { + return pm.startTarget(proto, listenIP, port, targetAddr) + } else { + logger.Info("Not adding target because not running") + } return nil } -func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error { - pm.Lock() - defer pm.Unlock() +func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error { + pm.mutex.Lock() + defer pm.mutex.Unlock() - protocol = strings.ToLower(protocol) - if protocol != "tcp" && protocol != "udp" { - return fmt.Errorf("unsupported protocol: %s", protocol) - } - - for i, target := range pm.targets { - if target.Listen == listen && - target.Port == port && - strings.ToLower(target.Protocol) == protocol { - - // Signal the serving goroutine to stop - select { - case <-target.cancel: - // Channel is already closed - default: - close(target.cancel) - } - - // Close the listener/connection - target.Lock() - switch protocol { - case "tcp": - if target.listener != nil { - target.listener.Close() - } - case "udp": - if target.udpConn != nil { - target.udpConn.Close() + switch proto { + case "tcp": + if targets, ok := pm.tcpTargets[listenIP]; ok { + delete(targets, port) + // Remove and close the corresponding TCP listener + for i, listener := range pm.listeners { + if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port { + listener.Close() + time.Sleep(50 * time.Millisecond) + // Remove from slice + pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) + break } } - target.Unlock() - - // Wait for the target to fully stop - <-target.done - - pm.targets = append(pm.targets[:i], pm.targets[i+1:]...) - return nil + } else { + return fmt.Errorf("target not found: %s:%d", listenIP, port) } + case "udp": + if targets, ok := pm.udpTargets[listenIP]; ok { + delete(targets, port) + // Remove and close the corresponding UDP connection + for i, conn := range pm.udpConns { + if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port { + conn.Close() + time.Sleep(50 * time.Millisecond) + // Remove from slice + pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) + break + } + } + } else { + return fmt.Errorf("target not found: %s:%d", listenIP, port) + } + default: + return fmt.Errorf("unsupported protocol: %s", proto) } - - return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port) + return nil } +// Start begins listening for all configured proxy targets func (pm *ProxyManager) Start() error { - pm.RLock() - defer pm.RUnlock() + pm.mutex.Lock() + defer pm.mutex.Unlock() - for _, target := range pm.targets { - target.Lock() - // If target is already running, skip it - if target.listener != nil || target.udpConn != nil { - target.Unlock() - continue - } + if pm.running { + return nil + } - // Mark the target as starting by creating a nil listener/connection - if strings.ToLower(target.Protocol) == "tcp" { - target.listener = nil - } else { - target.udpConn = nil - } - target.Unlock() - - switch strings.ToLower(target.Protocol) { - case "tcp": - go pm.serveTCP(target) - case "udp": - go pm.serveUDP(target) - default: - return fmt.Errorf("unsupported protocol: %s", target.Protocol) + // Start TCP targets + for listenIP, targets := range pm.tcpTargets { + for port, targetAddr := range targets { + if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil { + return fmt.Errorf("failed to start TCP target: %v", err) + } } } + + // Start UDP targets + for listenIP, targets := range pm.udpTargets { + for port, targetAddr := range targets { + if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil { + return fmt.Errorf("failed to start UDP target: %v", err) + } + } + } + + pm.running = true return nil } func (pm *ProxyManager) Stop() error { - pm.Lock() - defer pm.Unlock() + pm.mutex.Lock() + defer pm.mutex.Unlock() - var wg sync.WaitGroup - for _, target := range pm.targets { - wg.Add(1) - // Create a new variable in the loop to avoid closure issues - t := target // Take a local copy - go func() { - defer wg.Done() - close(t.cancel) - t.Lock() - if t.listener != nil { - t.listener.Close() - } - if t.udpConn != nil { - t.udpConn.Close() - } - t.Unlock() - // Wait for the target to fully stop - <-t.done - }() + if !pm.running { + return nil } - wg.Wait() + + // Set running to false first to signal handlers to stop + pm.running = false + + // Close TCP listeners + for i := len(pm.listeners) - 1; i >= 0; i-- { + listener := pm.listeners[i] + if err := listener.Close(); err != nil { + logger.Error("Error closing TCP listener: %v", err) + } + // Remove from slice + pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...) + } + + // Close UDP connections + for i := len(pm.udpConns) - 1; i >= 0; i-- { + conn := pm.udpConns[i] + if err := conn.Close(); err != nil { + logger.Error("Error closing UDP connection: %v", err) + } + // Remove from slice + pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...) + } + + // Clear the target maps + for k := range pm.tcpTargets { + delete(pm.tcpTargets, k) + } + for k := range pm.udpTargets { + delete(pm.udpTargets, k) + } + + // Give active connections a chance to close gracefully + time.Sleep(100 * time.Millisecond) + return nil } -func (pm *ProxyManager) serveTCP(target *ProxyTarget) { - defer close(target.done) // Signal that this target is fully stopped +func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error { + switch proto { + case "tcp": + listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port}) + if err != nil { + return fmt.Errorf("failed to create TCP listener: %v", err) + } - listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ - IP: net.ParseIP(target.Listen), - Port: target.Port, - }) - if err != nil { - logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err) - return + pm.listeners = append(pm.listeners, listener) + go pm.handleTCPProxy(listener, targetAddr) + + case "udp": + addr := &net.UDPAddr{Port: port} + conn, err := pm.tnet.ListenUDP(addr) + if err != nil { + return fmt.Errorf("failed to create UDP listener: %v", err) + } + + pm.udpConns = append(pm.udpConns, conn) + go pm.handleUDPProxy(conn, targetAddr) + + default: + return fmt.Errorf("unsupported protocol: %s", proto) } - target.Lock() - target.listener = listener - target.Unlock() + logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr) - defer listener.Close() - logger.Info("TCP proxy listening on %s", listener.Addr()) - - var activeConns sync.WaitGroup - acceptDone := make(chan struct{}) - - // Goroutine to handle shutdown signal - go func() { - <-target.cancel - close(acceptDone) - listener.Close() - }() + return nil +} +func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) { for { conn, err := listener.Accept() if err != nil { - select { - case <-target.cancel: - // Wait for active connections to finish - activeConns.Wait() + // Check if we're shutting down or the listener was closed + if !pm.running { return - default: - logger.Info("Failed to accept TCP connection: %v", err) - // Don't return here, try to accept new connections - time.Sleep(time.Second) - continue } + + // Check for specific network errors that indicate the listener is closed + if ne, ok := err.(net.Error); ok && !ne.Temporary() { + logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr()) + return + } + + logger.Error("Error accepting TCP connection: %v", err) + // Don't hammer the CPU if we hit a temporary error + time.Sleep(100 * time.Millisecond) + continue } - activeConns.Add(1) go func() { - defer activeConns.Done() - pm.handleTCPConnection(conn, target.Target, acceptDone) + target, err := net.Dial("tcp", targetAddr) + if err != nil { + logger.Error("Error connecting to target: %v", err) + conn.Close() + return + } + + // Create a WaitGroup to ensure both copy operations complete + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + io.Copy(target, conn) + target.Close() + }() + + go func() { + defer wg.Done() + io.Copy(conn, target) + conn.Close() + }() + + // Wait for both copies to complete + wg.Wait() }() } } -func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) { - defer clientConn.Close() - - serverConn, err := net.Dial("tcp", target) - if err != nil { - logger.Info("Failed to connect to target %s: %v", target, err) - return - } - defer serverConn.Close() - - // Create error channels for both copy operations - errc1 := make(chan error, 1) - errc2 := make(chan error, 1) - - // Copy from client to server - go func() { - _, err := io.Copy(serverConn, clientConn) - errc1 <- err - }() - - // Copy from server to client - go func() { - _, err := io.Copy(clientConn, serverConn) - errc2 <- err - }() - - // Wait for either copy to finish or done signal - select { - case <-done: - // Gracefully close connections without type assertions - if closer, ok := clientConn.(interface{ CloseRead() error }); ok { - closer.CloseRead() - } - if closer, ok := serverConn.(*gonet.TCPConn); ok { - closer.CloseRead() - } - case err := <-errc1: - if err != nil { - logger.Info("Error copying client->server: %v", err) - } - case err := <-errc2: - if err != nil { - logger.Info("Error copying server->client: %v", err) - } - } -} - -func (pm *ProxyManager) serveUDP(target *ProxyTarget) { - defer close(target.done) // Signal that this target is fully stopped - - addr := &net.UDPAddr{ - IP: net.ParseIP(target.Listen), - Port: target.Port, - } - - conn, err := pm.tnet.ListenUDP(addr) - if err != nil { - logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err) - return - } - - target.Lock() - target.udpConn = conn - target.Unlock() - - defer conn.Close() - logger.Info("UDP proxy listening on %s", conn.LocalAddr()) - - buffer := make([]byte, 65535) - var activeConns sync.WaitGroup +func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) { + buffer := make([]byte, 65507) // Max UDP packet size + clientConns := make(map[string]*net.UDPConn) + var clientsMutex sync.RWMutex for { - select { - case <-target.cancel: - activeConns.Wait() // Wait for all active UDP handlers to complete - return - default: - n, remoteAddr, err := conn.ReadFrom(buffer) - if err != nil { - select { - case <-target.cancel: - activeConns.Wait() - return - default: - logger.Info("Failed to read UDP packet: %v", err) - continue - } + n, remoteAddr, err := conn.ReadFrom(buffer) + if err != nil { + if !pm.running { + return } + logger.Error("Error reading UDP packet: %v", err) + continue + } - targetAddr, err := net.ResolveUDPAddr("udp", target.Target) + clientKey := remoteAddr.String() + clientsMutex.RLock() + targetConn, exists := clientConns[clientKey] + clientsMutex.RUnlock() + + if !exists { + targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr) if err != nil { - logger.Info("Failed to resolve target address %s: %v", target.Target, err) + logger.Error("Error resolving target address: %v", err) continue } - activeConns.Add(1) - go func(data []byte, remote net.Addr) { - defer activeConns.Done() - targetConn, err := net.DialUDP("udp", nil, targetAddr) - if err != nil { - logger.Info("Failed to connect to target %s: %v", target.Target, err) - return - } - defer targetConn.Close() + targetConn, err = net.DialUDP("udp", nil, targetUDPAddr) + if err != nil { + logger.Error("Error connecting to target: %v", err) + continue + } - select { - case <-target.cancel: - return - default: - _, err = targetConn.Write(data) + clientsMutex.Lock() + clientConns[clientKey] = targetConn + clientsMutex.Unlock() + + go func() { + buffer := make([]byte, 65507) + for { + n, _, err := targetConn.ReadFromUDP(buffer) if err != nil { - logger.Info("Failed to write to target: %v", err) + logger.Error("Error reading from target: %v", err) return } - response := make([]byte, 65535) - n, err := targetConn.Read(response) + _, err = conn.WriteTo(buffer[:n], remoteAddr) if err != nil { - logger.Info("Failed to read response from target: %v", err) + logger.Error("Error writing to client: %v", err) return } - - _, err = conn.WriteTo(response[:n], remote) - if err != nil { - logger.Info("Failed to write response to client: %v", err) - } } - }(buffer[:n], remoteAddr) + }() + } + + _, err = targetConn.Write(buffer[:n]) + if err != nil { + logger.Error("Error writing to target: %v", err) + targetConn.Close() + clientsMutex.Lock() + delete(clientConns, clientKey) + clientsMutex.Unlock() } } } diff --git a/proxy/types.go b/proxy/types.go deleted file mode 100644 index f189596..0000000 --- a/proxy/types.go +++ /dev/null @@ -1,29 +0,0 @@ -package proxy - -import ( - "log" - "net" - "sync" - - "golang.zx2c4.com/wireguard/tun/netstack" -) - -type ProxyTarget struct { - Protocol string - Listen string - Port int - Target string - cancel chan struct{} // Channel to signal shutdown - done chan struct{} // Channel to signal completion - listener net.Listener // For TCP - udpConn net.PacketConn // For UDP - sync.Mutex // Protect access to connection - activeConns sync.Map -} - -type ProxyManager struct { - targets []*ProxyTarget - tnet *netstack.Net - log *log.Logger - sync.RWMutex // Protect access to targets slice -}