From cda6fa677295aa48ce1781092675b78c35b793b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:24 +0200 Subject: [PATCH] feat(cli/proxy): add OTLP timeout flag and make proxy metrics resilient --- main.go | 17 ++++++++++++++++- proxy/proxy.go | 20 ++++++++++---------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 62bfe7c..2a31c2b 100644 --- a/main.go +++ b/main.go @@ -175,6 +175,7 @@ func main() { otelMetricsEndpoint string otelMetricsInsecure bool otelMetricsExportInterval time.Duration + otelMetricsTimeout time.Duration ) interfaceName = os.Getenv("INTERFACE") @@ -229,6 +230,14 @@ func main() { log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2) } } + otelMetricsTimeout = 10 * time.Second // default + if v := os.Getenv("OTEL_METRICS_TIMEOUT"); v != "" { + if d, err2 := time.ParseDuration(v); err2 == nil { + otelMetricsTimeout = d + } else { + log.Printf("WARN: invalid OTEL_METRICS_TIMEOUT=%q: %v", v, err2) + } + } if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -322,6 +331,7 @@ func main() { flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)") flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection") flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes") + flag.DurationVar(&otelMetricsTimeout, "otel-metrics-timeout", otelMetricsTimeout, "Timeout for OTLP exporter setup") flag.Parse() @@ -347,6 +357,7 @@ func main() { Endpoint: otelMetricsEndpoint, Insecure: otelMetricsInsecure, ExportInterval: otelMetricsExportInterval, + Timeout: otelMetricsTimeout, }, ServiceName: "gerbil", ServiceVersion: "1.0.0", @@ -543,6 +554,8 @@ func main() { // Register metrics endpoint only for Prometheus backend. // OTel backend pushes to a collector; no /metrics endpoint needed. + // Note: metricsPath is registered directly without httpMetricsMiddleware to prevent infinite recursion. + // The metricsHandler must not be wrapped by the middleware, as it would observe its own observation calls. if metricsHandler != nil { http.Handle(metricsPath, metricsHandler) logger.Info("Metrics endpoint enabled at %s", metricsPath) @@ -1162,10 +1175,12 @@ func removePeerInternal(publicKey string) error { // Get current peer info before removing to clear relay connections and bandwidth limits var wgIPs []string + allowedIPsCount := 0 device, err := wgClient.Device(interfaceName) if err == nil { for _, peer := range device.Peers { if peer.PublicKey.String() == publicKey { + allowedIPsCount = len(peer.AllowedIPs) // Extract WireGuard IPs from this peer's allowed IPs for _, allowedIP := range peer.AllowedIPs { wgIPs = append(wgIPs, allowedIP.IP.String()) @@ -1208,7 +1223,7 @@ func removePeerInternal(publicKey string) error { // Record metrics metrics.RecordPeersTotal(interfaceName, -1) - metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs))) + metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(allowedIPsCount)) return nil } diff --git a/proxy/proxy.go b/proxy/proxy.go index 71cf4ed..9b46e10 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -548,7 +548,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("SNI extraction failed: %v", err) return } - metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds()) + metrics.RecordProxyTLSHandshake(time.Since(clientHelloStart).Seconds()) if hostname == "" { log.Println("No SNI hostname found") @@ -596,8 +596,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { defer targetConn.Close() logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) - metrics.RecordActiveProxyConnection(hostname, 1) - defer metrics.RecordActiveProxyConnection(hostname, -1) + metrics.RecordActiveProxyConnection(1) + defer metrics.RecordActiveProxyConnection(-1) // Send PROXY protocol header if enabled if p.proxyProtocol { @@ -655,7 +655,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check local overrides first if _, isOverride := p.localOverrides[hostname]; isOverride { logger.Debug("Local override matched for hostname: %s", hostname) - metrics.RecordProxyRouteLookup("local_override", hostname) + metrics.RecordProxyRouteLookup("local_override") return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -668,7 +668,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { _, isLocal := p.localSNIs[hostname] p.localSNIsLock.RUnlock() if isLocal { - metrics.RecordProxyRouteLookup("local", hostname) + metrics.RecordProxyRouteLookup("local") return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -679,16 +679,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check cache first if cached, found := p.cache.Get(hostname); found { if cached == nil { - metrics.RecordProxyRouteLookup("cached_not_found", hostname) + metrics.RecordProxyRouteLookup("cached_not_found") return nil, nil // Cached negative result } logger.Debug("Cache hit for hostname: %s", hostname) - metrics.RecordProxyRouteLookup("cache_hit", hostname) + metrics.RecordProxyRouteLookup("cache_hit") return cached.(*RouteRecord), nil } logger.Debug("Cache miss for hostname: %s, querying API", hostname) - metrics.RecordProxyRouteLookup("cache_miss", hostname) + metrics.RecordProxyRouteLookup("cache_miss") // Query API with timeout ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) @@ -822,7 +822,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client }() bytesCopied, err := io.CopyBuffer(targetConn, clientReader, *bufPtr) - metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied) + metrics.RecordProxyBytesTransmitted("client_to_target", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy client->target error: %v", err) } @@ -842,7 +842,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client }() bytesCopied, err := io.CopyBuffer(clientConn, targetConn, *bufPtr) - metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied) + metrics.RecordProxyBytesTransmitted("target_to_client", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy target->client error: %v", err) }