diff --git a/main.go b/main.go index 7a99c4d..61c186f 100644 --- a/main.go +++ b/main.go @@ -2,7 +2,9 @@ package main import ( "bytes" + "context" "encoding/json" + "errors" "flag" "fmt" "io" @@ -21,6 +23,7 @@ import ( "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" "github.com/vishvananda/netlink" + "golang.org/x/sync/errgroup" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -217,6 +220,10 @@ func main() { logger.Init() logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + // Base context for the application; cancel on SIGINT/SIGTERM + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + // try to parse as http://host:port and set the listenAddr to the :port from this reachableAt. if reachableAt != "" && listenAddr == "" { if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") { @@ -324,10 +331,16 @@ func main() { // Ensure the WireGuard peers exist ensureWireguardPeers(wgconfig.Peers) - go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") + // Child error group derived from base context + group, groupCtx := errgroup.WithContext(ctx) + + // Periodic bandwidth reporting + group.Go(func() error { + return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth") + }) // Start the UDP proxy server - proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) + proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt) err = proxyRelay.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) @@ -371,18 +384,39 @@ func main() { http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) logger.Info("Starting HTTP server on %s", listenAddr) - // Run HTTP server in a goroutine - go func() { - if err := http.ListenAndServe(listenAddr, nil); err != nil { - logger.Error("HTTP server failed: %v", err) + // HTTP server with graceful shutdown on context cancel + server := &http.Server{ + Addr: listenAddr, + Handler: nil, + } + group.Go(func() error { + // http.ErrServerClosed is returned on graceful shutdown; not an error for us + if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + return err } - }() + return nil + }) + group.Go(func() error { + <-groupCtx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _ = server.Shutdown(shutdownCtx) + // Stop background components as the context is canceled + if proxySNI != nil { + _ = proxySNI.Stop() + } + if proxyRelay != nil { + proxyRelay.Stop() + } + return nil + }) - // Keep the main goroutine running - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) - <-sigCh - logger.Info("Shutting down servers...") + // Wait for all goroutines to finish + if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) { + logger.Error("Service exited with error: %v", err) + } else if errors.Is(err, context.Canceled) { + logger.Info("Context cancelled, shutting down") + } } func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { @@ -639,7 +673,7 @@ func ensureMSSClamping() error { if out, err := addCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)", chain, err, string(out)) - logger.Error(errMsg) + logger.Error("%s", errMsg) errors = append(errors, fmt.Errorf("%s", errMsg)) continue } @@ -656,7 +690,7 @@ func ensureMSSClamping() error { if out, err := checkCmd.CombinedOutput(); err != nil { errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)", chain, err, string(out)) - logger.Error(errMsg) + logger.Error("%s", errMsg) errors = append(errors, fmt.Errorf("%s", errMsg)) continue } @@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { }) } -func periodicBandwidthCheck(endpoint string) { +func periodicBandwidthCheck(ctx context.Context, endpoint string) error { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() - for range ticker.C { - if err := reportPeerBandwidth(endpoint); err != nil { - logger.Info("Failed to report peer bandwidth: %v", err) + for { + select { + case <-ticker.C: + if err := reportPeerBandwidth(endpoint); err != nil { + logger.Info("Failed to report peer bandwidth: %v", err) + } + case <-ctx.Done(): + return ctx.Err() } } } diff --git a/relay/relay.go b/relay/relay.go index e74ed87..e3fef04 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -1,6 +1,7 @@ package relay import ( + "context" "bytes" "encoding/binary" "encoding/json" @@ -112,6 +113,8 @@ type UDPProxyServer struct { connections sync.Map // map[string]*DestinationConn where key is destination "ip:port" privateKey wgtypes.Key packetChan chan Packet + ctx context.Context + cancel context.CancelFunc // Session tracking for WireGuard peers // Key format: "senderIndex:receiverIndex" @@ -123,14 +126,17 @@ type UDPProxyServer struct { ReachableAt string } -// NewUDPProxyServer initializes the server with a buffered packet channel. -func NewUDPProxyServer(addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { +// NewUDPProxyServer initializes the server with a buffered packet channel and derived context. +func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privateKey wgtypes.Key, reachableAt string) *UDPProxyServer { + ctx, cancel := context.WithCancel(parentCtx) return &UDPProxyServer{ addr: addr, serverURL: serverURL, privateKey: privateKey, packetChan: make(chan Packet, 1000), ReachableAt: reachableAt, + ctx: ctx, + cancel: cancel, } } @@ -177,17 +183,51 @@ func (s *UDPProxyServer) Start() error { } func (s *UDPProxyServer) Stop() { - s.conn.Close() + // Signal all background goroutines to stop + if s.cancel != nil { + s.cancel() + } + // Close listener to unblock reads + if s.conn != nil { + _ = s.conn.Close() + } + // Close all downstream UDP connections + s.connections.Range(func(key, value interface{}) bool { + if dc, ok := value.(*DestinationConn); ok && dc.conn != nil { + _ = dc.conn.Close() + } + return true + }) + // Close packet channel to stop workers + select { + case <-s.ctx.Done(): + default: + } + close(s.packetChan) } // readPackets continuously reads from the UDP socket and pushes packets into the channel. func (s *UDPProxyServer) readPackets() { for { + // Exit promptly if context is canceled + select { + case <-s.ctx.Done(): + return + default: + } buf := bufferPool.Get().([]byte) n, remoteAddr, err := s.conn.ReadFromUDP(buf) if err != nil { - logger.Error("Error reading UDP packet: %v", err) - continue + // If we're shutting down, exit + select { + case <-s.ctx.Done(): + bufferPool.Put(buf[:1500]) + return + default: + logger.Error("Error reading UDP packet: %v", err) + bufferPool.Put(buf[:1500]) + continue + } } s.packetChan <- Packet{data: buf[:n], remoteAddr: remoteAddr, n: n} } @@ -588,49 +628,67 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd // Add a cleanup method to periodically remove idle connections func (s *UDPProxyServer) cleanupIdleConnections() { ticker := time.NewTicker(5 * time.Minute) - for range ticker.C { - now := time.Now() - s.connections.Range(func(key, value interface{}) bool { - destConn := value.(*DestinationConn) - if now.Sub(destConn.lastUsed) > 10*time.Minute { - destConn.conn.Close() - s.connections.Delete(key) - } - return true - }) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.connections.Range(func(key, value interface{}) bool { + destConn := value.(*DestinationConn) + if now.Sub(destConn.lastUsed) > 10*time.Minute { + destConn.conn.Close() + s.connections.Delete(key) + } + return true + }) + case <-s.ctx.Done(): + return + } } } // New method to periodically remove idle sessions func (s *UDPProxyServer) cleanupIdleSessions() { ticker := time.NewTicker(5 * time.Minute) - for range ticker.C { - now := time.Now() - s.wgSessions.Range(func(key, value interface{}) bool { - session := value.(*WireGuardSession) - if now.Sub(session.LastSeen) > 15*time.Minute { - s.wgSessions.Delete(key) - logger.Debug("Removed idle session: %s", key) - } - return true - }) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.wgSessions.Range(func(key, value interface{}) bool { + session := value.(*WireGuardSession) + if now.Sub(session.LastSeen) > 15*time.Minute { + s.wgSessions.Delete(key) + logger.Debug("Removed idle session: %s", key) + } + return true + }) + case <-s.ctx.Done(): + return + } } } // New method to periodically remove idle proxy mappings func (s *UDPProxyServer) cleanupIdleProxyMappings() { ticker := time.NewTicker(10 * time.Minute) - for range ticker.C { - now := time.Now() - s.proxyMappings.Range(func(key, value interface{}) bool { - mapping := value.(ProxyMapping) - // Remove mappings that haven't been used in 30 minutes - if now.Sub(mapping.LastUsed) > 30*time.Minute { - s.proxyMappings.Delete(key) - logger.Debug("Removed idle proxy mapping: %s", key) - } - return true - }) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.proxyMappings.Range(func(key, value interface{}) bool { + mapping := value.(ProxyMapping) + // Remove mappings that haven't been used in 30 minutes + if now.Sub(mapping.LastUsed) > 30*time.Minute { + s.proxyMappings.Delete(key) + logger.Debug("Removed idle proxy mapping: %s", key) + } + return true + }) + case <-s.ctx.Done(): + return + } } } @@ -943,23 +1001,29 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { // cleanupIdleCommunicationPatterns periodically removes idle communication patterns func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { ticker := time.NewTicker(10 * time.Minute) - for range ticker.C { - now := time.Now() - s.commPatterns.Range(func(key, value interface{}) bool { - pattern := value.(*CommunicationPattern) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.commPatterns.Range(func(key, value interface{}) bool { + pattern := value.(*CommunicationPattern) - // Get the most recent activity - lastActivity := pattern.LastFromClient - if pattern.LastFromDest.After(lastActivity) { - lastActivity = pattern.LastFromDest - } + // Get the most recent activity + lastActivity := pattern.LastFromClient + if pattern.LastFromDest.After(lastActivity) { + lastActivity = pattern.LastFromDest + } - // Remove patterns that haven't had activity in 20 minutes - if now.Sub(lastActivity) > 20*time.Minute { - s.commPatterns.Delete(key) - logger.Debug("Removed idle communication pattern: %s", key) - } - return true - }) + // Remove patterns that haven't had activity in 20 minutes + if now.Sub(lastActivity) > 20*time.Minute { + s.commPatterns.Delete(key) + logger.Debug("Removed idle communication pattern: %s", key) + } + return true + }) + case <-s.ctx.Done(): + return + } } }