enhancement: base context + errgroup; propagate cancellation; graceful shutdown

- main: add base context via signal.NotifyContext; establish errgroup and use it to supervise background tasks; convert ticker to context-aware periodicBandwidthCheck; run HTTP server under errgroup and add graceful shutdown; treat context.Canceled as normal exit
- relay: thread parent context through UDPProxyServer; add cancel func; make packet reader, workers, and cleanup tickers exit on ctx.Done; Stop cancels, closes listener and downstream UDP connections, and closes packet channel to drain workers
- proxy: drop earlier parent context hook for SNI proxy per review; rely on existing Stop() for graceful shutdown

Benefits:
- unified lifecycle and deterministic shutdown across components
- prevents leaked goroutines/tickers and closes sockets cleanly
- consolidated error handling via g.Wait(), with context cancellation treated as non-error
- sets foundation for child errgroups and future structured concurrency
This commit is contained in:
Laurence
2025-11-16 05:59:34 +00:00
parent 709df6db3e
commit 697f4131e7
2 changed files with 173 additions and 70 deletions

75
main.go
View File

@@ -2,7 +2,9 @@ package main
import (
"bytes"
"context"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
@@ -21,6 +23,7 @@ import (
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
"github.com/vishvananda/netlink"
"golang.org/x/sync/errgroup"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
@@ -217,6 +220,10 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Base context for the application; cancel on SIGINT/SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
// try to parse as http://host:port and set the listenAddr to the :port from this reachableAt.
if reachableAt != "" && listenAddr == "" {
if strings.HasPrefix(reachableAt, "http://") || strings.HasPrefix(reachableAt, "https://") {
@@ -324,10 +331,16 @@ func main() {
// Ensure the WireGuard peers exist
ensureWireguardPeers(wgconfig.Peers)
go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth")
// Child error group derived from base context
group, groupCtx := errgroup.WithContext(ctx)
// Periodic bandwidth reporting
group.Go(func() error {
return periodicBandwidthCheck(groupCtx, remoteConfigURL+"/gerbil/receive-bandwidth")
})
// Start the UDP proxy server
proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt)
proxyRelay = relay.NewUDPProxyServer(groupCtx, ":21820", remoteConfigURL, key, reachableAt)
err = proxyRelay.Start()
if err != nil {
logger.Fatal("Failed to start UDP proxy server: %v", err)
@@ -371,18 +384,39 @@ func main() {
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
logger.Info("Starting HTTP server on %s", listenAddr)
// Run HTTP server in a goroutine
go func() {
if err := http.ListenAndServe(listenAddr, nil); err != nil {
logger.Error("HTTP server failed: %v", err)
// HTTP server with graceful shutdown on context cancel
server := &http.Server{
Addr: listenAddr,
Handler: nil,
}
group.Go(func() error {
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return err
}
}()
return nil
})
group.Go(func() error {
<-groupCtx.Done()
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = server.Shutdown(shutdownCtx)
// Stop background components as the context is canceled
if proxySNI != nil {
_ = proxySNI.Stop()
}
if proxyRelay != nil {
proxyRelay.Stop()
}
return nil
})
// Keep the main goroutine running
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
<-sigCh
logger.Info("Shutting down servers...")
// Wait for all goroutines to finish
if err := group.Wait(); err != nil && !errors.Is(err, context.Canceled) {
logger.Error("Service exited with error: %v", err)
} else if errors.Is(err, context.Canceled) {
logger.Info("Context cancelled, shutting down")
}
}
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
@@ -639,7 +673,7 @@ func ensureMSSClamping() error {
if out, err := addCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Failed to add MSS clamping rule for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -656,7 +690,7 @@ func ensureMSSClamping() error {
if out, err := checkCmd.CombinedOutput(); err != nil {
errMsg := fmt.Sprintf("Rule verification failed for chain %s: %v (output: %s)",
chain, err, string(out))
logger.Error(errMsg)
logger.Error("%s", errMsg)
errors = append(errors, fmt.Errorf("%s", errMsg))
continue
}
@@ -977,13 +1011,18 @@ func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) {
})
}
func periodicBandwidthCheck(endpoint string) {
func periodicBandwidthCheck(ctx context.Context, endpoint string) error {
ticker := time.NewTicker(10 * time.Second)
defer ticker.Stop()
for range ticker.C {
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
for {
select {
case <-ticker.C:
if err := reportPeerBandwidth(endpoint); err != nil {
logger.Info("Failed to report peer bandwidth: %v", err)
}
case <-ctx.Done():
return ctx.Err()
}
}
}