mirror of
https://github.com/fosrl/newt.git
synced 2026-02-08 05:56:40 +00:00
Handle freeing ports correctly
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
)
|
||||
@@ -27,6 +28,7 @@ func (pm *ProxyManager) AddTarget(protocol, listen string, port int, target stri
|
||||
Port: port,
|
||||
Target: target,
|
||||
cancel: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
pm.targets = append(pm.targets, newTarget)
|
||||
@@ -45,23 +47,42 @@ func (pm *ProxyManager) RemoveTarget(protocol, listen string, port int) error {
|
||||
if target.Listen == listen &&
|
||||
target.Port == port &&
|
||||
strings.ToLower(target.Protocol) == protocol {
|
||||
|
||||
// Signal the serving goroutine to stop
|
||||
// close(target.cancel)
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Channel is already closed, no need to close it again
|
||||
default:
|
||||
close(target.cancel)
|
||||
}
|
||||
|
||||
// Close the appropriate listener/connection based on protocol
|
||||
target.Lock()
|
||||
switch protocol {
|
||||
case "tcp":
|
||||
if target.listener != nil {
|
||||
target.listener.Close()
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Listener was already closed by Stop()
|
||||
default:
|
||||
target.listener.Close()
|
||||
}
|
||||
}
|
||||
case "udp":
|
||||
if target.udpConn != nil {
|
||||
target.udpConn.Close()
|
||||
select {
|
||||
case <-target.cancel:
|
||||
// Connection was already closed by Stop()
|
||||
default:
|
||||
target.udpConn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
target.Unlock()
|
||||
|
||||
// Wait for the target to fully stop
|
||||
<-target.done
|
||||
|
||||
// Remove the target from the slice
|
||||
pm.targets = append(pm.targets[:i], pm.targets[i+1:]...)
|
||||
return nil
|
||||
@@ -76,7 +97,16 @@ func (pm *ProxyManager) Start() error {
|
||||
defer pm.RUnlock()
|
||||
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i] // Use pointer to modify the target in the slice
|
||||
target := &pm.targets[i]
|
||||
|
||||
// Skip already running targets
|
||||
target.Lock()
|
||||
if target.listener != nil || target.udpConn != nil {
|
||||
target.Unlock()
|
||||
continue
|
||||
}
|
||||
target.Unlock()
|
||||
|
||||
switch strings.ToLower(target.Protocol) {
|
||||
case "tcp":
|
||||
go pm.serveTCP(target)
|
||||
@@ -93,27 +123,36 @@ func (pm *ProxyManager) Stop() error {
|
||||
pm.Lock()
|
||||
defer pm.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := range pm.targets {
|
||||
target := &pm.targets[i]
|
||||
close(target.cancel)
|
||||
target.Lock()
|
||||
if target.listener != nil {
|
||||
target.listener.Close()
|
||||
}
|
||||
if target.udpConn != nil {
|
||||
target.udpConn.Close()
|
||||
}
|
||||
target.Unlock()
|
||||
wg.Add(1)
|
||||
go func(t *ProxyTarget) {
|
||||
defer wg.Done()
|
||||
close(t.cancel)
|
||||
t.Lock()
|
||||
if t.listener != nil {
|
||||
t.listener.Close()
|
||||
}
|
||||
if t.udpConn != nil {
|
||||
t.udpConn.Close()
|
||||
}
|
||||
t.Unlock()
|
||||
// Wait for the target to fully stop
|
||||
<-t.done
|
||||
}(target)
|
||||
}
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
|
||||
listener, err := pm.tnet.ListenTCP(&net.TCPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
})
|
||||
log.Printf("Listening on %s:%d", target.Listen, target.Port)
|
||||
if err != nil {
|
||||
log.Printf("Failed to start TCP listener for %s:%d: %v", target.Listen, target.Port, err)
|
||||
return
|
||||
@@ -126,14 +165,13 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
defer listener.Close()
|
||||
log.Printf("TCP proxy listening on %s", listener.Addr())
|
||||
|
||||
// Channel to signal active connections to close
|
||||
done := make(chan struct{})
|
||||
var activeConns sync.WaitGroup
|
||||
acceptDone := make(chan struct{})
|
||||
|
||||
// Goroutine to handle shutdown signal
|
||||
go func() {
|
||||
<-target.cancel
|
||||
close(done)
|
||||
close(acceptDone)
|
||||
listener.Close()
|
||||
}()
|
||||
|
||||
@@ -147,6 +185,8 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
return
|
||||
default:
|
||||
log.Printf("Failed to accept TCP connection: %v", err)
|
||||
// Don't return here, try to accept new connections
|
||||
time.Sleep(time.Second)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -154,7 +194,7 @@ func (pm *ProxyManager) serveTCP(target *ProxyTarget) {
|
||||
activeConns.Add(1)
|
||||
go func() {
|
||||
defer activeConns.Done()
|
||||
pm.handleTCPConnection(conn, target.Target, done)
|
||||
pm.handleTCPConnection(conn, target.Target, acceptDone)
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -198,6 +238,8 @@ func (pm *ProxyManager) handleTCPConnection(clientConn net.Conn, target string,
|
||||
}
|
||||
|
||||
func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
||||
defer close(target.done) // Signal that this target is fully stopped
|
||||
|
||||
addr := &net.UDPAddr{
|
||||
IP: net.ParseIP(target.Listen),
|
||||
Port: target.Port,
|
||||
@@ -217,16 +259,19 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
||||
log.Printf("UDP proxy listening on %s", conn.LocalAddr())
|
||||
|
||||
buffer := make([]byte, 65535)
|
||||
var activeConns sync.WaitGroup
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait() // Wait for all active UDP handlers to complete
|
||||
return
|
||||
default:
|
||||
n, remoteAddr, err := conn.ReadFrom(buffer)
|
||||
if err != nil {
|
||||
select {
|
||||
case <-target.cancel:
|
||||
activeConns.Wait()
|
||||
return
|
||||
default:
|
||||
log.Printf("Failed to read UDP packet: %v", err)
|
||||
@@ -240,7 +285,9 @@ func (pm *ProxyManager) serveUDP(target *ProxyTarget) {
|
||||
continue
|
||||
}
|
||||
|
||||
activeConns.Add(1)
|
||||
go func(data []byte, remote net.Addr) {
|
||||
defer activeConns.Done()
|
||||
targetConn, err := net.DialUDP("udp", nil, targetAddr)
|
||||
if err != nil {
|
||||
log.Printf("Failed to connect to target %s: %v", target.Target, err)
|
||||
|
||||
@@ -14,9 +14,10 @@ type ProxyTarget struct {
|
||||
Port int
|
||||
Target string
|
||||
cancel chan struct{} // Channel to signal shutdown
|
||||
done chan struct{} // Channel to signal completion
|
||||
listener net.Listener // For TCP
|
||||
udpConn net.PacketConn // For UDP
|
||||
sync.Mutex // Protect access to connections
|
||||
sync.Mutex // Protect access to connection
|
||||
}
|
||||
|
||||
type ProxyManager struct {
|
||||
|
||||
Reference in New Issue
Block a user