From 526f9c8b4e59fe483db1450b9622d27dfc9168bc Mon Sep 17 00:00:00 2001 From: Owen Date: Sat, 6 Dec 2025 12:16:03 -0500 Subject: [PATCH] Remove diff --- 37.diff | 385 -------------------------------------------------------- 1 file changed, 385 deletions(-) delete mode 100644 37.diff diff --git a/37.diff b/37.diff deleted file mode 100644 index d80429c..0000000 --- a/37.diff +++ /dev/null @@ -1,385 +0,0 @@ -diff --git a/main.go b/main.go -index 7a99c4d..61c186f 100644 ---- a/main.go -+++ b/main.go -@@ -2,7 +2,9 @@ package main - - import ( - "bytes" -+ "context" - "encoding/json" -+ "errors" - "flag" - "fmt" - "io" -@@ -21,6 +23,7 @@ import ( - "github.com/fosrl/gerbil/proxy" - "github.com/fosrl/gerbil/relay" - "github.com/vishvananda/netlink" -+ "golang.org/x/sync/errgroup" - "golang.zx2c4.com/wireguard/wgctrl" - "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - ) -@@ -217,6 +220,10 @@ func main() { - logger.Init() - logger.GetLogger().SetLevel(parseLogLevel(logLevel)) - -+ // Base context for the application; cancel on SIGINT/SIGTERM -+ ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) -+ defer stop() -+ - // try to parse as http://host:port and set the listenAddr to the :port from this reachableAt. - if reachableAt != "" && listenAddr == "" { - if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") { -@@ -324,10 +331,16 @@ func main() { - // Ensure the WireGuard peers exist - ensureWireguardPeers(wgconfig.Peers) - -- go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") -+ // Child error group derived from base context -+ group, groupCtx := errgroup.WithContext(ctx) -+ -+ // Periodic bandwidth reporting -+ group.Go(func() error { -+ return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth") -+ }) - - // Start the UDP proxy server -- proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) -+ proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt) - err = proxyRelay.Start() - if err != nil { - logger.Fatal("Failed to start UDP proxy server: %v", err) -@@ -371,18 +384,39 @@ func main() { - http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) - logger.Info("Starting HTTP server on %s", listenAddr) - -- // Run HTTP server in a goroutine -- go func() { -- if err := http.ListenAndServe(listenAddr, nil); err != nil { -- logger.Error("HTTP server failed: %v", err) -+ // HTTP server with graceful shutdown on context cancel -+ server := &http.Server{ -+ Addr: listenAddr, -+ Handler: nil, -+ } -+ group.Go(func() error { -+ // http.ErrServerClosed is returned on graceful shutdown; not an error for us -+ if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { -+ return err -+ } -+ return nil -+ }) -+ group.Go(func() error { -+ <-groupCtx.Done() -+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) -+ defer cancel() -+ _ = server.Shutdown(shutdownCtx) -+ // Stop background components as the context is canceled -+ if proxySNI != nil { -+ _ = proxySNI.Stop() -+ } -+ if proxyRelay != nil { -+ proxyRelay.Stop() - } -- }() -+ return nil -+ }) - -- // Keep the main goroutine running -- sigCh := make(chan os.Signal, 1) -- signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) -- <-sigCh -- logger.Info("Shutting down servers...") -+ // Wait for all goroutines to finish -+ if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) { -+ logger.Error("Service exited with error: %v", err) -+ } else if errors.Is(err, context.Canceled) { -+ logger.Info("Context cancelled, shutting down") -+ } - } - - func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { -@@ -639,7 +673,7 @@ func ensureMSSClamping() error { - if out, err := addCmd.CombinedOutput(); err != nil { - errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", - chain, err, string(out)) -- logger.Error(errMsg) -+ logger.Error("%s", errMsg) - errors = append(errors, fmt.Errorf("%s", errMsg)) - continue - } -@@ -656,7 +690,7 @@ func ensureMSSClamping() error { - if out, err := checkCmd.CombinedOutput(); err != nil { - errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", - chain, err, string(out)) -- logger.Error(errMsg) -+ logger.Error("%s", errMsg) - errors = append(errors, fmt.Errorf("%s", errMsg)) - continue - } -@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { - }) - } - --func periodicBandwidthCheck(endpoint string) { -+func periodicBandwidthCheck(ctx context.Context, endpoint string) error { - ticker := time.NewTicker(10 * time.Second) - defer ticker.Stop() - -- for range ticker.C { -- if err := reportPeerBandwidth(endpoint); err != nil { -- logger.Info("Failed to report peer bandwidth: %v", err) -+ for { -+ select { -+ case <-ticker.C: -+ if err := reportPeerBandwidth(endpoint); err != nil { -+ logger.Info("Failed to report peer bandwidth: %v", err) -+ } -+ case <-ctx.Done(): -+ return ctx.Err() - } - } - } -diff --git a/relay/relay.go b/relay/relay.go -index e74ed87..e3fef04 100644 ---- a/relay/relay.go -+++ b/relay/relay.go -@@ -1,6 +1,7 @@ - package relay - - import ( -+ "context" - "bytes" - "encoding/binary" - "encoding/json" -@@ -112,6 +113,8 @@ type UDPProxyServer struct { - connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" - privateKey wgtypes.Key - packetChan chan Packet -+ ctx context.Context -+ cancel context.CancelFunc - - // Session tracking for WireGuard peers - // Key format: "senderIndex:receiverIndex" -@@ -123,14 +126,17 @@ type UDPProxyServer struct { - ReachableAt string - } - --// NewUDPProxyServer initializes the server with a buffered packet channel. --func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { -+// NewUDPProxyServer initializes the server with a buffered packet channel and derived context. -+func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { -+ ctx, cancel := context.WithCancel(parentCtx) - return &UDPProxyServer{ - addr: addr, - serverURL: serverURL, - privateKey: privateKey, - packetChan: make(chan Packet, 1000), - ReachableAt: reachableAt, -+ ctx: ctx, -+ cancel: cancel, - } - } - -@@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error { - } - - func (s *UDPProxyServer) Stop() { -- s.conn.Close() -+ // Signal all background goroutines to stop -+ if s.cancel != nil { -+ s.cancel() -+ } -+ // Close listener to unblock reads -+ if s.conn != nil { -+ _ = s.conn.Close() -+ } -+ // Close all downstream UDP connections -+ s.connections.Range(func(key, value interface{}) bool { -+ if dc, ok := value.(*DestinationConn); ok && dc.conn != nil { -+ _ = dc.conn.Close() -+ } -+ return true -+ }) -+ // Close packet channel to stop workers -+ select { -+ case <-s.ctx.Done(): -+ default: -+ } -+ close(s.packetChan) - } - - // readPackets continuously reads from the UDP socket and pushes packets into the channel. - func (s *UDPProxyServer) readPackets() { - for { -+ // Exit promptly if context is canceled -+ select { -+ case <-s.ctx.Done(): -+ return -+ default: -+ } - buf := bufferPool.Get().([]byte) - n, remoteAddr, err := s.conn.ReadFromUDP(buf) - if err != nil { -- logger.Error("Error reading UDP packet: %v", err) -- continue -+ // If we're shutting down, exit -+ select { -+ case <-s.ctx.Done(): -+ bufferPool.Put(buf[:1500]) -+ return -+ default: -+ logger.Error("Error reading UDP packet: %v", err) -+ bufferPool.Put(buf[:1500]) -+ continue -+ } - } - s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n} - } -@@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd - // Add a cleanup method to periodically remove idle connections - func (s *UDPProxyServer) cleanupIdleConnections() { - ticker := time.NewTicker(5 * time.Minute) -- for range ticker.C { -- now := time.Now() -- s.connections.Range(func(key, value interface{}) bool { -- destConn := value.(*DestinationConn) -- if now.Sub(destConn.lastUsed) > 10*time.Minute { -- destConn.conn.Close() -- s.connections.Delete(key) -- } -- return true -- }) -+ defer ticker.Stop() -+ for { -+ select { -+ case <-ticker.C: -+ now := time.Now() -+ s.connections.Range(func(key, value interface{}) bool { -+ destConn := value.(*DestinationConn) -+ if now.Sub(destConn.lastUsed) > 10*time.Minute { -+ destConn.conn.Close() -+ s.connections.Delete(key) -+ } -+ return true -+ }) -+ case <-s.ctx.Done(): -+ return -+ } - } - } - - // New method to periodically remove idle sessions - func (s *UDPProxyServer) cleanupIdleSessions() { - ticker := time.NewTicker(5 * time.Minute) -- for range ticker.C { -- now := time.Now() -- s.wgSessions.Range(func(key, value interface{}) bool { -- session := value.(*WireGuardSession) -- if now.Sub(session.LastSeen) > 15*time.Minute { -- s.wgSessions.Delete(key) -- logger.Debug("Removed idle session: %s", key) -- } -- return true -- }) -+ defer ticker.Stop() -+ for { -+ select { -+ case <-ticker.C: -+ now := time.Now() -+ s.wgSessions.Range(func(key, value interface{}) bool { -+ session := value.(*WireGuardSession) -+ if now.Sub(session.LastSeen) > 15*time.Minute { -+ s.wgSessions.Delete(key) -+ logger.Debug("Removed idle session: %s", key) -+ } -+ return true -+ }) -+ case <-s.ctx.Done(): -+ return -+ } - } - } - - // New method to periodically remove idle proxy mappings - func (s *UDPProxyServer) cleanupIdleProxyMappings() { - ticker := time.NewTicker(10 * time.Minute) -- for range ticker.C { -- now := time.Now() -- s.proxyMappings.Range(func(key, value interface{}) bool { -- mapping := value.(ProxyMapping) -- // Remove mappings that haven't been used in 30 minutes -- if now.Sub(mapping.LastUsed) > 30*time.Minute { -- s.proxyMappings.Delete(key) -- logger.Debug("Removed idle proxy mapping: %s", key) -- } -- return true -- }) -+ defer ticker.Stop() -+ for { -+ select { -+ case <-ticker.C: -+ now := time.Now() -+ s.proxyMappings.Range(func(key, value interface{}) bool { -+ mapping := value.(ProxyMapping) -+ // Remove mappings that haven't been used in 30 minutes -+ if now.Sub(mapping.LastUsed) > 30*time.Minute { -+ s.proxyMappings.Delete(key) -+ logger.Debug("Removed idle proxy mapping: %s", key) -+ } -+ return true -+ }) -+ case <-s.ctx.Done(): -+ return -+ } - } - } - -@@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { - // cleanupIdleCommunicationPatterns periodically removes idle communication patterns - func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { - ticker := time.NewTicker(10 * time.Minute) -- for range ticker.C { -- now := time.Now() -- s.commPatterns.Range(func(key, value interface{}) bool { -- pattern := value.(*CommunicationPattern) -- -- // Get the most recent activity -- lastActivity := pattern.LastFromClient -- if pattern.LastFromDest.After(lastActivity) { -- lastActivity = pattern.LastFromDest -- } -+ defer ticker.Stop() -+ for { -+ select { -+ case <-ticker.C: -+ now := time.Now() -+ s.commPatterns.Range(func(key, value interface{}) bool { -+ pattern := value.(*CommunicationPattern) -+ -+ // Get the most recent activity -+ lastActivity := pattern.LastFromClient -+ if pattern.LastFromDest.After(lastActivity) { -+ lastActivity = pattern.LastFromDest -+ } - -- // Remove patterns that haven't had activity in 20 minutes -- if now.Sub(lastActivity) > 20*time.Minute { -- s.commPatterns.Delete(key) -- logger.Debug("Removed idle communication pattern: %s", key) -- } -- return true -- }) -+ // Remove patterns that haven't had activity in 20 minutes -+ if now.Sub(lastActivity) > 20*time.Minute { -+ s.commPatterns.Delete(key) -+ logger.Debug("Removed idle communication pattern: %s", key) -+ } -+ return true -+ }) -+ case <-s.ctx.Done(): -+ return -+ } - } - }