mirror of
https://github.com/fosrl/gerbil.git
synced 2026-02-07 21:46:40 +00:00
Merge branch 'LaurenceJJones-enhancement/errgroup-context-propagation'
This commit is contained in:
385
37.diff
Normal file
385
37.diff
Normal file
@@ -0,0 +1,385 @@
|
||||
diff --git a/main.go b/main.go
|
||||
index 7a99c4d..61c186f 100644
|
||||
--- a/main.go
|
||||
+++ b/main.go
|
||||
@@ -2,7 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
+ "context"
|
||||
"encoding/json"
|
||||
+ "errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
"github.com/vishvananda/netlink"
|
||||
+ "golang.org/x/sync/errgroup"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -217,6 +220,10 @@ func main() {
|
||||
logger.Init()
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
+ // Base context for the application; cancel on SIGINT/SIGTERM
|
||||
+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
+ defer stop()
|
||||
+
|
||||
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
|
||||
if reachableAt != "" && listenAddr == "" {
|
||||
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
|
||||
@@ -324,10 +331,16 @@ func main() {
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
- go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
||||
+ // Child error group derived from base context
|
||||
+ group, groupCtx := errgroup.WithContext(ctx)
|
||||
+
|
||||
+ // Periodic bandwidth reporting
|
||||
+ group.Go(func() error {
|
||||
+ return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
|
||||
+ })
|
||||
|
||||
// Start the UDP proxy server
|
||||
- proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
||||
+ proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
@@ -371,18 +384,39 @@ func main() {
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
- // Run HTTP server in a goroutine
|
||||
- go func() {
|
||||
- if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
- logger.Error("HTTP server failed: %v", err)
|
||||
+ // HTTP server with graceful shutdown on context cancel
|
||||
+ server := &http.Server{
|
||||
+ Addr: listenAddr,
|
||||
+ Handler: nil,
|
||||
+ }
|
||||
+ group.Go(func() error {
|
||||
+ // http.ErrServerClosed is returned on graceful shutdown; not an error for us
|
||||
+ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
+ return err
|
||||
+ }
|
||||
+ return nil
|
||||
+ })
|
||||
+ group.Go(func() error {
|
||||
+ <-groupCtx.Done()
|
||||
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
+ defer cancel()
|
||||
+ _ = server.Shutdown(shutdownCtx)
|
||||
+ // Stop background components as the context is canceled
|
||||
+ if proxySNI != nil {
|
||||
+ _ = proxySNI.Stop()
|
||||
+ }
|
||||
+ if proxyRelay != nil {
|
||||
+ proxyRelay.Stop()
|
||||
}
|
||||
- }()
|
||||
+ return nil
|
||||
+ })
|
||||
|
||||
- // Keep the main goroutine running
|
||||
- sigCh := make(chan os.Signal, 1)
|
||||
- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
- <-sigCh
|
||||
- logger.Info("Shutting down servers...")
|
||||
+ // Wait for all goroutines to finish
|
||||
+ if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
+ logger.Error("Service exited with error: %v", err)
|
||||
+ } else if errors.Is(err, context.Canceled) {
|
||||
+ logger.Info("Context cancelled, shutting down")
|
||||
+ }
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
- logger.Error(errMsg)
|
||||
+ logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
- logger.Error(errMsg)
|
||||
+ logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
-func periodicBandwidthCheck(endpoint string) {
|
||||
+func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
- for range ticker.C {
|
||||
- if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
- logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
+ for {
|
||||
+ select {
|
||||
+ case <-ticker.C:
|
||||
+ if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
+ logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
+ }
|
||||
+ case <-ctx.Done():
|
||||
+ return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
diff --git a/relay/relay.go b/relay/relay.go
|
||||
index e74ed87..e3fef04 100644
|
||||
--- a/relay/relay.go
|
||||
+++ b/relay/relay.go
|
||||
@@ -1,6 +1,7 @@
|
||||
package relay
|
||||
|
||||
import (
|
||||
+ "context"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
@@ -112,6 +113,8 @@ type UDPProxyServer struct {
|
||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||
privateKey wgtypes.Key
|
||||
packetChan chan Packet
|
||||
+ ctx context.Context
|
||||
+ cancel context.CancelFunc
|
||||
|
||||
// Session tracking for WireGuard peers
|
||||
// Key format: "senderIndex:receiverIndex"
|
||||
@@ -123,14 +126,17 @@ type UDPProxyServer struct {
|
||||
ReachableAt string
|
||||
}
|
||||
|
||||
-// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||
-func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
+// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
+func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
+ ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 1000),
|
||||
ReachableAt: reachableAt,
|
||||
+ ctx: ctx,
|
||||
+ cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error {
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) Stop() {
|
||||
- s.conn.Close()
|
||||
+ // Signal all background goroutines to stop
|
||||
+ if s.cancel != nil {
|
||||
+ s.cancel()
|
||||
+ }
|
||||
+ // Close listener to unblock reads
|
||||
+ if s.conn != nil {
|
||||
+ _ = s.conn.Close()
|
||||
+ }
|
||||
+ // Close all downstream UDP connections
|
||||
+ s.connections.Range(func(key, value interface{}) bool {
|
||||
+ if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
|
||||
+ _ = dc.conn.Close()
|
||||
+ }
|
||||
+ return true
|
||||
+ })
|
||||
+ // Close packet channel to stop workers
|
||||
+ select {
|
||||
+ case <-s.ctx.Done():
|
||||
+ default:
|
||||
+ }
|
||||
+ close(s.packetChan)
|
||||
}
|
||||
|
||||
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||
func (s *UDPProxyServer) readPackets() {
|
||||
for {
|
||||
+ // Exit promptly if context is canceled
|
||||
+ select {
|
||||
+ case <-s.ctx.Done():
|
||||
+ return
|
||||
+ default:
|
||||
+ }
|
||||
buf := bufferPool.Get().([]byte)
|
||||
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
- logger.Error("Error reading UDP packet: %v", err)
|
||||
- continue
|
||||
+ // If we're shutting down, exit
|
||||
+ select {
|
||||
+ case <-s.ctx.Done():
|
||||
+ bufferPool.Put(buf[:1500])
|
||||
+ return
|
||||
+ default:
|
||||
+ logger.Error("Error reading UDP packet: %v", err)
|
||||
+ bufferPool.Put(buf[:1500])
|
||||
+ continue
|
||||
+ }
|
||||
}
|
||||
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||
}
|
||||
@@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
// Add a cleanup method to periodically remove idle connections
|
||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
- for range ticker.C {
|
||||
- now := time.Now()
|
||||
- s.connections.Range(func(key, value interface{}) bool {
|
||||
- destConn := value.(*DestinationConn)
|
||||
- if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
- destConn.conn.Close()
|
||||
- s.connections.Delete(key)
|
||||
- }
|
||||
- return true
|
||||
- })
|
||||
+ defer ticker.Stop()
|
||||
+ for {
|
||||
+ select {
|
||||
+ case <-ticker.C:
|
||||
+ now := time.Now()
|
||||
+ s.connections.Range(func(key, value interface{}) bool {
|
||||
+ destConn := value.(*DestinationConn)
|
||||
+ if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
+ destConn.conn.Close()
|
||||
+ s.connections.Delete(key)
|
||||
+ }
|
||||
+ return true
|
||||
+ })
|
||||
+ case <-s.ctx.Done():
|
||||
+ return
|
||||
+ }
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle sessions
|
||||
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
- for range ticker.C {
|
||||
- now := time.Now()
|
||||
- s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
- session := value.(*WireGuardSession)
|
||||
- if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
- s.wgSessions.Delete(key)
|
||||
- logger.Debug("Removed idle session: %s", key)
|
||||
- }
|
||||
- return true
|
||||
- })
|
||||
+ defer ticker.Stop()
|
||||
+ for {
|
||||
+ select {
|
||||
+ case <-ticker.C:
|
||||
+ now := time.Now()
|
||||
+ s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
+ session := value.(*WireGuardSession)
|
||||
+ if now.Sub(session.LastSeen) > 15*time.Minute {
|
||||
+ s.wgSessions.Delete(key)
|
||||
+ logger.Debug("Removed idle session: %s", key)
|
||||
+ }
|
||||
+ return true
|
||||
+ })
|
||||
+ case <-s.ctx.Done():
|
||||
+ return
|
||||
+ }
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle proxy mappings
|
||||
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
- for range ticker.C {
|
||||
- now := time.Now()
|
||||
- s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
- mapping := value.(ProxyMapping)
|
||||
- // Remove mappings that haven't been used in 30 minutes
|
||||
- if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
- s.proxyMappings.Delete(key)
|
||||
- logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
- }
|
||||
- return true
|
||||
- })
|
||||
+ defer ticker.Stop()
|
||||
+ for {
|
||||
+ select {
|
||||
+ case <-ticker.C:
|
||||
+ now := time.Now()
|
||||
+ s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
+ mapping := value.(ProxyMapping)
|
||||
+ // Remove mappings that haven't been used in 30 minutes
|
||||
+ if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
+ s.proxyMappings.Delete(key)
|
||||
+ logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
+ }
|
||||
+ return true
|
||||
+ })
|
||||
+ case <-s.ctx.Done():
|
||||
+ return
|
||||
+ }
|
||||
}
|
||||
}
|
||||
|
||||
@@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
- for range ticker.C {
|
||||
- now := time.Now()
|
||||
- s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
- pattern := value.(*CommunicationPattern)
|
||||
-
|
||||
- // Get the most recent activity
|
||||
- lastActivity := pattern.LastFromClient
|
||||
- if pattern.LastFromDest.After(lastActivity) {
|
||||
- lastActivity = pattern.LastFromDest
|
||||
- }
|
||||
+ defer ticker.Stop()
|
||||
+ for {
|
||||
+ select {
|
||||
+ case <-ticker.C:
|
||||
+ now := time.Now()
|
||||
+ s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
+ pattern := value.(*CommunicationPattern)
|
||||
+
|
||||
+ // Get the most recent activity
|
||||
+ lastActivity := pattern.LastFromClient
|
||||
+ if pattern.LastFromDest.After(lastActivity) {
|
||||
+ lastActivity = pattern.LastFromDest
|
||||
+ }
|
||||
|
||||
- // Remove patterns that haven't had activity in 20 minutes
|
||||
- if now.Sub(lastActivity) > 20*time.Minute {
|
||||
- s.commPatterns.Delete(key)
|
||||
- logger.Debug("Removed idle communication pattern: %s", key)
|
||||
- }
|
||||
- return true
|
||||
- })
|
||||
+ // Remove patterns that haven't had activity in 20 minutes
|
||||
+ if now.Sub(lastActivity) > 20*time.Minute {
|
||||
+ s.commPatterns.Delete(key)
|
||||
+ logger.Debug("Removed idle communication pattern: %s", key)
|
||||
+ }
|
||||
+ return true
|
||||
+ })
|
||||
+ case <-s.ctx.Done():
|
||||
+ return
|
||||
+ }
|
||||
}
|
||||
}
|
||||
75
main.go
75
main.go
@@ -2,7 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -21,6 +23,7 @@ import (
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
"github.com/vishvananda/netlink"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.zx2c4.com/wireguard/wgctrl"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
)
|
||||
@@ -217,6 +220,10 @@ func main() {
|
||||
logger.Init()
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Base context for the application; cancel on SIGINT/SIGTERM
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
|
||||
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
|
||||
if reachableAt != "" && listenAddr == "" {
|
||||
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
|
||||
@@ -324,10 +331,16 @@ func main() {
|
||||
// Ensure the WireGuard peers exist
|
||||
ensureWireguardPeers(wgconfig.Peers)
|
||||
|
||||
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
|
||||
// Child error group derived from base context
|
||||
group, groupCtx := errgroup.WithContext(ctx)
|
||||
|
||||
// Periodic bandwidth reporting
|
||||
group.Go(func() error {
|
||||
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
|
||||
})
|
||||
|
||||
// Start the UDP proxy server
|
||||
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
|
||||
proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
|
||||
err = proxyRelay.Start()
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to start UDP proxy server: %v", err)
|
||||
@@ -371,18 +384,39 @@ func main() {
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
// Run HTTP server in a goroutine
|
||||
go func() {
|
||||
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||
logger.Error("HTTP server failed: %v", err)
|
||||
// HTTP server with graceful shutdown on context cancel
|
||||
server := &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: nil,
|
||||
}
|
||||
group.Go(func() error {
|
||||
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
|
||||
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return err
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
})
|
||||
group.Go(func() error {
|
||||
<-groupCtx.Done()
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = server.Shutdown(shutdownCtx)
|
||||
// Stop background components as the context is canceled
|
||||
if proxySNI != nil {
|
||||
_ = proxySNI.Stop()
|
||||
}
|
||||
if proxyRelay != nil {
|
||||
proxyRelay.Stop()
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
// Keep the main goroutine running
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigCh
|
||||
logger.Info("Shutting down servers...")
|
||||
// Wait for all goroutines to finish
|
||||
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
|
||||
logger.Error("Service exited with error: %v", err)
|
||||
} else if errors.Is(err, context.Canceled) {
|
||||
logger.Info("Context cancelled, shutting down")
|
||||
}
|
||||
}
|
||||
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := addCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
|
||||
if out, err := checkCmd.CombinedOutput(); err != nil {
|
||||
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
|
||||
chain, err, string(out))
|
||||
logger.Error(errMsg)
|
||||
logger.Error("%s", errMsg)
|
||||
errors = append(errors, fmt.Errorf("%s", errMsg))
|
||||
continue
|
||||
}
|
||||
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
}
|
||||
|
||||
func periodicBandwidthCheck(endpoint string) {
|
||||
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := reportPeerBandwidth(endpoint); err != nil {
|
||||
logger.Info("Failed to report peer bandwidth: %v", err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
173
relay/relay.go
173
relay/relay.go
@@ -2,6 +2,7 @@ package relay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -141,6 +142,8 @@ type UDPProxyServer struct {
|
||||
connections sync.Map // map[string]*DestinationConn where key is destination "ip:port"
|
||||
privateKey wgtypes.Key
|
||||
packetChan chan Packet
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
|
||||
// Session tracking for WireGuard peers
|
||||
// Key format: "senderIndex:receiverIndex"
|
||||
@@ -152,14 +155,17 @@ type UDPProxyServer struct {
|
||||
ReachableAt string
|
||||
}
|
||||
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel.
|
||||
func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
// NewUDPProxyServer initializes the server with a buffered packet channel and derived context.
|
||||
func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer {
|
||||
ctx, cancel := context.WithCancel(parentCtx)
|
||||
return &UDPProxyServer{
|
||||
addr: addr,
|
||||
serverURL: serverURL,
|
||||
privateKey: privateKey,
|
||||
packetChan: make(chan Packet, 1000),
|
||||
ReachableAt: reachableAt,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -206,19 +212,51 @@ func (s *UDPProxyServer) Start() error {
|
||||
}
|
||||
|
||||
func (s *UDPProxyServer) Stop() {
|
||||
s.conn.Close()
|
||||
// Signal all background goroutines to stop
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
// Close listener to unblock reads
|
||||
if s.conn != nil {
|
||||
_ = s.conn.Close()
|
||||
}
|
||||
// Close all downstream UDP connections
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
if dc, ok := value.(*DestinationConn); ok && dc.conn != nil {
|
||||
_ = dc.conn.Close()
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Close packet channel to stop workers
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
default:
|
||||
}
|
||||
close(s.packetChan)
|
||||
}
|
||||
|
||||
// readPackets continuously reads from the UDP socket and pushes packets into the channel.
|
||||
func (s *UDPProxyServer) readPackets() {
|
||||
for {
|
||||
// Exit promptly if context is canceled
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
buf := bufferPool.Get().([]byte)
|
||||
n, remoteAddr, err := s.conn.ReadFromUDP(buf)
|
||||
if err != nil {
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
// Return buffer to pool on read error to avoid leaks
|
||||
bufferPool.Put(buf[:1500])
|
||||
continue
|
||||
// If we're shutting down, exit
|
||||
select {
|
||||
case <-s.ctx.Done():
|
||||
bufferPool.Put(buf[:1500])
|
||||
return
|
||||
default:
|
||||
logger.Error("Error reading UDP packet: %v", err)
|
||||
bufferPool.Put(buf[:1500])
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n}
|
||||
}
|
||||
@@ -617,50 +655,69 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd
|
||||
// Add a cleanup method to periodically remove idle connections
|
||||
func (s *UDPProxyServer) cleanupIdleConnections() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.connections.Range(func(key, value interface{}) bool {
|
||||
destConn := value.(*DestinationConn)
|
||||
if now.Sub(destConn.lastUsed) > 10*time.Minute {
|
||||
destConn.conn.Close()
|
||||
s.connections.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle sessions
|
||||
func (s *UDPProxyServer) cleanupIdleSessions() {
|
||||
ticker := time.NewTicker(5 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.wgSessions.Range(func(key, value interface{}) bool {
|
||||
session := value.(*WireGuardSession)
|
||||
// Use thread-safe method to read LastSeen
|
||||
if now.Sub(session.GetLastSeen()) > 15*time.Minute {
|
||||
s.wgSessions.Delete(key)
|
||||
logger.Debug("Removed idle session: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// New method to periodically remove idle proxy mappings
|
||||
func (s *UDPProxyServer) cleanupIdleProxyMappings() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.proxyMappings.Range(func(key, value interface{}) bool {
|
||||
mapping := value.(ProxyMapping)
|
||||
// Remove mappings that haven't been used in 30 minutes
|
||||
if now.Sub(mapping.LastUsed) > 30*time.Minute {
|
||||
s.proxyMappings.Delete(key)
|
||||
logger.Debug("Removed idle proxy mapping: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -972,23 +1029,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) {
|
||||
// cleanupIdleCommunicationPatterns periodically removes idle communication patterns
|
||||
func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() {
|
||||
ticker := time.NewTicker(10 * time.Minute)
|
||||
for range ticker.C {
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
now := time.Now()
|
||||
s.commPatterns.Range(func(key, value interface{}) bool {
|
||||
pattern := value.(*CommunicationPattern)
|
||||
|
||||
// Get the most recent activity
|
||||
lastActivity := pattern.LastFromClient
|
||||
if pattern.LastFromDest.After(lastActivity) {
|
||||
lastActivity = pattern.LastFromDest
|
||||
}
|
||||
// Get the most recent activity
|
||||
lastActivity := pattern.LastFromClient
|
||||
if pattern.LastFromDest.After(lastActivity) {
|
||||
lastActivity = pattern.LastFromDest
|
||||
}
|
||||
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
// Remove patterns that haven't had activity in 20 minutes
|
||||
if now.Sub(lastActivity) > 20*time.Minute {
|
||||
s.commPatterns.Delete(key)
|
||||
logger.Debug("Removed idle communication pattern: %s", key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
case <-s.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user