Rewrite proxy manager

This commit is contained in:
Owen Schwartz
2025-01-20 21:11:06 -05:00
parent 759780508a
commit 3a63657822
2 changed files with 253 additions and 278 deletions

View File

@@ -4,328 +4,332 @@ import (
"fmt" "fmt"
"io" "io"
"net" "net"
"strings"
"sync" "sync"
"time" "time"
"github.com/fosrl/newt/logger" "github.com/fosrl/newt/logger"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
) )
// Target represents a proxy target with its address and port
type Target struct {
Address string
Port int
}
// ProxyManager handles the creation and management of proxy connections
type ProxyManager struct {
tnet *netstack.Net
tcpTargets map[string]map[int]string // map[listenIP]map[port]targetAddress
udpTargets map[string]map[int]string
listeners []*gonet.TCPListener
udpConns []*gonet.UDPConn
running bool
mutex sync.RWMutex
}
// NewProxyManager creates a new proxy manager instance
func NewProxyManager(tnet *netstack.Net) *ProxyManager { func NewProxyManager(tnet *netstack.Net) *ProxyManager {
return &ProxyManager{ return &ProxyManager{
tnet: tnet, tnet: tnet,
tcpTargets: make(map[string]map[int]string),
udpTargets: make(map[string]map[int]string),
listeners: make([]*gonet.TCPListener, 0),
udpConns: make([]*gonet.UDPConn, 0),
} }
} }
func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target string) error { // AddTarget adds a new target for proxying
pm.Lock() func (pm *ProxyManager) AddTarget(proto, listenIP string, port int, targetAddr string) error {
defer pm.Unlock() pm.mutex.Lock()
defer pm.mutex.Unlock()
logger.Info("Adding target: %s://%s:%d -> %s", protocol, listen, port, target) switch proto {
newTarget := &ProxyTarget{
Protocol: protocol,
Listen: listen,
Port: port,
Target: target,
cancel: make(chan struct{}),
done: make(chan struct{}),
}
pm.targets = append(pm.targets, newTarget)
return nil
}
func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
pm.Lock()
defer pm.Unlock()
protocol = strings.ToLower(protocol)
if protocol != "tcp" && protocol != "udp" {
return fmt.Errorf("unsupported protocol: %s", protocol)
}
for i, target := range pm.targets {
if target.Listen == listen &&
target.Port == port &&
strings.ToLower(target.Protocol) == protocol {
// Signal the serving goroutine to stop
select {
case <-target.cancel:
// Channel is already closed
default:
close(target.cancel)
}
// Close the listener/connection
target.Lock()
switch protocol {
case "tcp": case "tcp":
if target.listener != nil { if pm.tcpTargets[listenIP] == nil {
target.listener.Close() pm.tcpTargets[listenIP] = make(map[int]string)
} }
pm.tcpTargets[listenIP][port] = targetAddr
case "udp": case "udp":
if target.udpConn != nil { if pm.udpTargets[listenIP] == nil {
target.udpConn.Close() pm.udpTargets[listenIP] = make(map[int]string)
}
}
target.Unlock()
// Wait for the target to fully stop
<-target.done
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
return nil
} }
pm.udpTargets[listenIP][port] = targetAddr
default:
return fmt.Errorf("unsupported protocol: %s", proto)
} }
return fmt.Errorf("target not found for %s %s:%d", protocol, listen, port) if pm.running {
} return pm.startTarget(proto, listenIP, port, targetAddr)
func (pm *ProxyManager) Start() error {
pm.RLock()
defer pm.RUnlock()
for _, target := range pm.targets {
target.Lock()
// If target is already running, skip it
if target.listener != nil || target.udpConn != nil {
target.Unlock()
continue
}
// Mark the target as starting by creating a nil listener/connection
if strings.ToLower(target.Protocol) == "tcp" {
target.listener = nil
} else { } else {
target.udpConn = nil logger.Info("Not adding target because not running")
}
return nil
} }
target.Unlock()
switch strings.ToLower(target.Protocol) { func (pm *ProxyManager) RemoveTarget(proto, listenIP string, port int) error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
switch proto {
case "tcp": case "tcp":
go pm.serveTCP(target) if targets, ok := pm.tcpTargets[listenIP]; ok {
delete(targets, port)
// Remove and close the corresponding TCP listener
for i, listener := range pm.listeners {
if addr, ok := listener.Addr().(*net.TCPAddr); ok && addr.Port == port {
listener.Close()
time.Sleep(50 * time.Millisecond)
// Remove from slice
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
break
}
}
} else {
return fmt.Errorf("target not found: %s:%d", listenIP, port)
}
case "udp": case "udp":
go pm.serveUDP(target) if targets, ok := pm.udpTargets[listenIP]; ok {
delete(targets, port)
// Remove and close the corresponding UDP connection
for i, conn := range pm.udpConns {
if addr, ok := conn.LocalAddr().(*net.UDPAddr); ok && addr.Port == port {
conn.Close()
time.Sleep(50 * time.Millisecond)
// Remove from slice
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
break
}
}
} else {
return fmt.Errorf("target not found: %s:%d", listenIP, port)
}
default: default:
return fmt.Errorf("unsupported protocol: %s", target.Protocol) return fmt.Errorf("unsupported protocol: %s", proto)
}
return nil
}
// Start begins listening for all configured proxy targets
func (pm *ProxyManager) Start() error {
pm.mutex.Lock()
defer pm.mutex.Unlock()
if pm.running {
return nil
}
// Start TCP targets
for listenIP, targets := range pm.tcpTargets {
for port, targetAddr := range targets {
if err := pm.startTarget("tcp", listenIP, port, targetAddr); err != nil {
return fmt.Errorf("failed to start TCP target: %v", err)
} }
} }
}
// Start UDP targets
for listenIP, targets := range pm.udpTargets {
for port, targetAddr := range targets {
if err := pm.startTarget("udp", listenIP, port, targetAddr); err != nil {
return fmt.Errorf("failed to start UDP target: %v", err)
}
}
}
pm.running = true
return nil return nil
} }
func (pm *ProxyManager) Stop() error { func (pm *ProxyManager) Stop() error {
pm.Lock() pm.mutex.Lock()
defer pm.Unlock() defer pm.mutex.Unlock()
var wg sync.WaitGroup if !pm.running {
for _, target := range pm.targets {
wg.Add(1)
// Create a new variable in the loop to avoid closure issues
t := target // Take a local copy
go func() {
defer wg.Done()
close(t.cancel)
t.Lock()
if t.listener != nil {
t.listener.Close()
}
if t.udpConn != nil {
t.udpConn.Close()
}
t.Unlock()
// Wait for the target to fully stop
<-t.done
}()
}
wg.Wait()
return nil return nil
} }
func (pm *ProxyManager) serveTCP(target *ProxyTarget) { // Set running to false first to signal handlers to stop
defer close(target.done) // Signal that this target is fully stopped pm.running = false
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{ // Close TCP listeners
IP: net.ParseIP(target.Listen), for i := len(pm.listeners) - 1; i >= 0; i-- {
Port: target.Port, listener := pm.listeners[i]
}) if err := listener.Close(); err != nil {
if err != nil { logger.Error("Error closing TCP listener: %v", err)
logger.Info("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err) }
return // Remove from slice
pm.listeners = append(pm.listeners[:i], pm.listeners[i+1:]...)
} }
target.Lock() // Close UDP connections
target.listener = listener for i := len(pm.udpConns) - 1; i >= 0; i-- {
target.Unlock() conn := pm.udpConns[i]
if err := conn.Close(); err != nil {
logger.Error("Error closing UDP connection: %v", err)
}
// Remove from slice
pm.udpConns = append(pm.udpConns[:i], pm.udpConns[i+1:]...)
}
defer listener.Close() // Clear the target maps
logger.Info("TCP proxy listening on %s", listener.Addr()) for k := range pm.tcpTargets {
delete(pm.tcpTargets, k)
}
for k := range pm.udpTargets {
delete(pm.udpTargets, k)
}
var activeConns sync.WaitGroup // Give active connections a chance to close gracefully
acceptDone := make(chan struct{}) time.Sleep(100 * time.Millisecond)
// Goroutine to handle shutdown signal return nil
go func() { }
<-target.cancel
close(acceptDone)
listener.Close()
}()
func (pm *ProxyManager) startTarget(proto, listenIP string, port int, targetAddr string) error {
switch proto {
case "tcp":
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{Port: port})
if err != nil {
return fmt.Errorf("failed to create TCP listener: %v", err)
}
pm.listeners = append(pm.listeners, listener)
go pm.handleTCPProxy(listener, targetAddr)
case "udp":
addr := &net.UDPAddr{Port: port}
conn, err := pm.tnet.ListenUDP(addr)
if err != nil {
return fmt.Errorf("failed to create UDP listener: %v", err)
}
pm.udpConns = append(pm.udpConns, conn)
go pm.handleUDPProxy(conn, targetAddr)
default:
return fmt.Errorf("unsupported protocol: %s", proto)
}
logger.Info("Started %s proxy from %s:%d to %s", proto, listenIP, port, targetAddr)
return nil
}
func (pm *ProxyManager) handleTCPProxy(listener net.Listener, targetAddr string) {
for { for {
conn, err := listener.Accept() conn, err := listener.Accept()
if err != nil { if err != nil {
select { // Check if we're shutting down or the listener was closed
case <-target.cancel: if !pm.running {
// Wait for active connections to finish
activeConns.Wait()
return return
default: }
logger.Info("Failed to accept TCP connection: %v", err)
// Don't return here, try to accept new connections // Check for specific network errors that indicate the listener is closed
time.Sleep(time.Second) if ne, ok := err.(net.Error); ok && !ne.Temporary() {
logger.Info("TCP listener closed, stopping proxy handler for %v", listener.Addr())
return
}
logger.Error("Error accepting TCP connection: %v", err)
// Don't hammer the CPU if we hit a temporary error
time.Sleep(100 * time.Millisecond)
continue continue
} }
}
activeConns.Add(1)
go func() { go func() {
defer activeConns.Done() target, err := net.Dial("tcp", targetAddr)
pm.handleTCPConnection(conn, target.Target, acceptDone)
}()
}
}
func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string, done chan struct{}) {
defer clientConn.Close()
serverConn, err := net.Dial("tcp", target)
if err != nil { if err != nil {
logger.Info("Failed to connect to target %s: %v", target, err) logger.Error("Error connecting to target: %v", err)
return conn.Close()
}
defer serverConn.Close()
// Create error channels for both copy operations
errc1 := make(chan error, 1)
errc2 := make(chan error, 1)
// Copy from client to server
go func() {
_, err := io.Copy(serverConn, clientConn)
errc1 <- err
}()
// Copy from server to client
go func() {
_, err := io.Copy(clientConn, serverConn)
errc2 <- err
}()
// Wait for either copy to finish or done signal
select {
case <-done:
// Gracefully close connections without type assertions
if closer, ok := clientConn.(interface{ CloseRead() error }); ok {
closer.CloseRead()
}
if closer, ok := serverConn.(*gonet.TCPConn); ok {
closer.CloseRead()
}
case err := <-errc1:
if err != nil {
logger.Info("Error copying client->server: %v", err)
}
case err := <-errc2:
if err != nil {
logger.Info("Error copying server->client: %v", err)
}
}
}
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
defer close(target.done) // Signal that this target is fully stopped
addr := &net.UDPAddr{
IP: net.ParseIP(target.Listen),
Port: target.Port,
}
conn, err := pm.tnet.ListenUDP(addr)
if err != nil {
logger.Info("Failed to start UDP listener for %s:%d: %v", target.Listen, target.Port, err)
return return
} }
target.Lock() // Create a WaitGroup to ensure both copy operations complete
target.udpConn = conn var wg sync.WaitGroup
target.Unlock() wg.Add(2)
defer conn.Close() go func() {
logger.Info("UDP proxy listening on %s", conn.LocalAddr()) defer wg.Done()
io.Copy(target, conn)
target.Close()
}()
buffer := make([]byte, 65535) go func() {
var activeConns sync.WaitGroup defer wg.Done()
io.Copy(conn, target)
conn.Close()
}()
// Wait for both copies to complete
wg.Wait()
}()
}
}
func (pm *ProxyManager) handleUDPProxy(conn *gonet.UDPConn, targetAddr string) {
buffer := make([]byte, 65507) // Max UDP packet size
clientConns := make(map[string]*net.UDPConn)
var clientsMutex sync.RWMutex
for { for {
select {
case <-target.cancel:
activeConns.Wait() // Wait for all active UDP handlers to complete
return
default:
n, remoteAddr, err := conn.ReadFrom(buffer) n, remoteAddr, err := conn.ReadFrom(buffer)
if err != nil { if err != nil {
select { if !pm.running {
case <-target.cancel:
activeConns.Wait()
return return
default:
logger.Info("Failed to read UDP packet: %v", err)
continue
} }
} logger.Error("Error reading UDP packet: %v", err)
targetAddr, err := net.ResolveUDPAddr("udp", target.Target)
if err != nil {
logger.Info("Failed to resolve target address %s: %v", target.Target, err)
continue continue
} }
activeConns.Add(1) clientKey := remoteAddr.String()
go func(data []byte, remote net.Addr) { clientsMutex.RLock()
defer activeConns.Done() targetConn, exists := clientConns[clientKey]
targetConn, err := net.DialUDP("udp", nil, targetAddr) clientsMutex.RUnlock()
if err != nil {
logger.Info("Failed to connect to target %s: %v", target.Target, err)
return
}
defer targetConn.Close()
select { if !exists {
case <-target.cancel: targetUDPAddr, err := net.ResolveUDPAddr("udp", targetAddr)
return
default:
_, err = targetConn.Write(data)
if err != nil { if err != nil {
logger.Info("Failed to write to target: %v", err) logger.Error("Error resolving target address: %v", err)
continue
}
targetConn, err = net.DialUDP("udp", nil, targetUDPAddr)
if err != nil {
logger.Error("Error connecting to target: %v", err)
continue
}
clientsMutex.Lock()
clientConns[clientKey] = targetConn
clientsMutex.Unlock()
go func() {
buffer := make([]byte, 65507)
for {
n, _, err := targetConn.ReadFromUDP(buffer)
if err != nil {
logger.Error("Error reading from target: %v", err)
return return
} }
response := make([]byte, 65535) _, err = conn.WriteTo(buffer[:n], remoteAddr)
n, err := targetConn.Read(response)
if err != nil { if err != nil {
logger.Info("Failed to read response from target: %v", err) logger.Error("Error writing to client: %v", err)
return return
} }
}
}()
}
_, err = conn.WriteTo(response[:n], remote) _, err = targetConn.Write(buffer[:n])
if err != nil { if err != nil {
logger.Info("Failed to write response to client: %v", err) logger.Error("Error writing to target: %v", err)
} targetConn.Close()
} clientsMutex.Lock()
}(buffer[:n], remoteAddr) delete(clientConns, clientKey)
clientsMutex.Unlock()
} }
} }
} }

View File

@@ -1,29 +0,0 @@
package proxy
import (
"log"
"net"
"sync"
"golang.zx2c4.com/wireguard/tun/netstack"
)
type ProxyTarget struct {
Protocol string
Listen string
Port int
Target string
cancel chan struct{} // Channel to signal shutdown
done chan struct{} // Channel to signal completion
listener net.Listener // For TCP
udpConn net.PacketConn // For UDP
sync.Mutex // Protect access to connection
activeConns sync.Map
}
type ProxyManager struct {
targets []*ProxyTarget
tnet *netstack.Net
log *log.Logger
sync.RWMutex // Protect access to targets slice
}