Merge branch 'enhancement/errgroup-context-propagation' of github.com:LaurenceJJones/gerbil into LaurenceJJones-enhancement/errgroup-context-propagation

This commit is contained in:
Owen
2025-12-06 12:14:58 -05:00
3 changed files with 560 additions and 73 deletions

View File

@@ -2,6 +2,7 @@ package relay
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
@@ -141,6 +142,8 @@ type UDPProxyServer struct {
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
privateKey wgtypes.Key
packetChan chan Packet
ctx context.Context
cancel context.CancelFunc
// Session tracking for WireGuard peers
// Key format: "senderIndex:receiverIndex"
@@ -152,14 +155,17 @@ type UDPProxyServer struct {
ReachableAt string
}
// NewUDPProxyServer initializes the server with a buffered packet channel.
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
ctx, cancel := context.WithCancel(parentCtx)
return &UDPProxyServer{
addr: addr,
serverURL: serverURL,
privateKey: privateKey,
packetChan: make(chan Packet, 1000),
ReachableAt: reachableAt,
ctx: ctx,
cancel: cancel,
}
}
@@ -206,19 +212,51 @@ func (s *UDPProxyServer) Start() error {
}
func (s *UDPProxyServer) Stop() {
s.conn.Close()
// Signal all background goroutines to stop
if s.cancel != nil {
s.cancel()
}
// Close listener to unblock reads
if s.conn != nil {
_ = s.conn.Close()
}
// Close all downstream UDP connections
s.connections.Range(func(key, value interface{}) bool {
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
_ = dc.conn.Close()
}
return true
})
// Close packet channel to stop workers
select {
case <-s.ctx.Done():
default:
}
close(s.packetChan)
}
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
func (s *UDPProxyServer) readPackets() {
for {
// Exit promptly if context is canceled
select {
case <-s.ctx.Done():
return
default:
}
buf := bufferPool.Get().([]byte)
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
if err != nil {
logger.Error("Error reading UDP packet: %v", err)
// Return buffer to pool on read error to avoid leaks
bufferPool.Put(buf[:1500])
continue
// If we're shutting down, exit
select {
case <-s.ctx.Done():
bufferPool.Put(buf[:1500])
return
default:
logger.Error("Error reading UDP packet: %v", err)
bufferPool.Put(buf[:1500])
continue
}
}
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
}
@@ -617,50 +655,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
// Add a cleanup method to periodically remove idle connections
func (s *UDPProxyServer) cleanupIdleConnections() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.connections.Range(func(key, value interface{}) bool {
destConn := value.(*DestinationConn)
if now.Sub(destConn.lastUsed) > 10*time.Minute {
destConn.conn.Close()
s.connections.Delete(key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
// New method to periodically remove idle sessions
func (s *UDPProxyServer) cleanupIdleSessions() {
ticker := time.NewTicker(5 * time.Minute)
for range ticker.C {
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.wgSessions.Range(func(key, value interface{}) bool {
session := value.(*WireGuardSession)
// Use thread-safe method to read LastSeen
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
s.wgSessions.Delete(key)
logger.Debug("Removed idle session: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
// New method to periodically remove idle proxy mappings
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.proxyMappings.Range(func(key, value interface{}) bool {
mapping := value.(ProxyMapping)
// Remove mappings that haven't been used in 30 minutes
if now.Sub(mapping.LastUsed) > 30*time.Minute {
s.proxyMappings.Delete(key)
logger.Debug("Removed idle proxy mapping: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}
@@ -972,23 +1029,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
ticker := time.NewTicker(10 * time.Minute)
for range ticker.C {
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
defer ticker.Stop()
for {
select {
case <-ticker.C:
now := time.Now()
s.commPatterns.Range(func(key, value interface{}) bool {
pattern := value.(*CommunicationPattern)
// Get the most recent activity
lastActivity := pattern.LastFromClient
if pattern.LastFromDest.After(lastActivity) {
lastActivity = pattern.LastFromDest
}
// Get the most recent activity
lastActivity := pattern.LastFromClient
if pattern.LastFromDest.After(lastActivity) {
lastActivity = pattern.LastFromDest
}
// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
// Remove patterns that haven't had activity in 20 minutes
if now.Sub(lastActivity) > 20*time.Minute {
s.commPatterns.Delete(key)
logger.Debug("Removed idle communication pattern: %s", key)
}
return true
})
case <-s.ctx.Done():
return
}
}
}