feat(cli/proxy): add OTLP timeout flag and make proxy metrics resilient

This commit is contained in:
Marc Schäfer
2026-05-04 00:12:24 +02:00
parent 191b4fa26a
commit cda6fa6772
2 changed files with 26 additions and 11 deletions

17
main.go
View File

@@ -175,6 +175,7 @@ func main() {
otelMetricsEndpoint string
otelMetricsInsecure bool
otelMetricsExportInterval time.Duration
otelMetricsTimeout time.Duration
)
interfaceName = os.Getenv("INTERFACE")
@@ -229,6 +230,14 @@ func main() {
log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2)
}
}
otelMetricsTimeout = 10 * time.Second // default
if v := os.Getenv("OTEL_METRICS_TIMEOUT"); v != "" {
if d, err2 := time.ParseDuration(v); err2 == nil {
otelMetricsTimeout = d
} else {
log.Printf("WARN: invalid OTEL_METRICS_TIMEOUT=%q: %v", v, err2)
}
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
@@ -322,6 +331,7 @@ func main() {
flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)")
flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection")
flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes")
flag.DurationVar(&otelMetricsTimeout, "otel-metrics-timeout", otelMetricsTimeout, "Timeout for OTLP exporter setup")
flag.Parse()
@@ -347,6 +357,7 @@ func main() {
Endpoint: otelMetricsEndpoint,
Insecure: otelMetricsInsecure,
ExportInterval: otelMetricsExportInterval,
Timeout: otelMetricsTimeout,
},
ServiceName: "gerbil",
ServiceVersion: "1.0.0",
@@ -543,6 +554,8 @@ func main() {
// Register metrics endpoint only for Prometheus backend.
// OTel backend pushes to a collector; no /metrics endpoint needed.
// Note: metricsPath is registered directly, without httpMetricsMiddleware, to prevent infinite recursion.
// The metricsHandler must not be wrapped by the middleware; otherwise serving the metrics would itself be recorded as a metric observation.
if metricsHandler != nil {
http.Handle(metricsPath, metricsHandler)
logger.Info("Metrics endpoint enabled at %s", metricsPath)
@@ -1162,10 +1175,12 @@ func removePeerInternal(publicKey string) error {
// Get current peer info before removing to clear relay connections and bandwidth limits
var wgIPs []string
allowedIPsCount := 0
device, err := wgClient.Device(interfaceName)
if err == nil {
for _, peer := range device.Peers {
if peer.PublicKey.String() == publicKey {
allowedIPsCount = len(peer.AllowedIPs)
// Extract WireGuard IPs from this peer's allowed IPs
for _, allowedIP := range peer.AllowedIPs {
wgIPs = append(wgIPs, allowedIP.IP.String())
@@ -1208,7 +1223,7 @@ func removePeerInternal(publicKey string) error {
// Record metrics
metrics.RecordPeersTotal(interfaceName, -1)
metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs)))
metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(allowedIPsCount))
return nil
}

View File

@@ -548,7 +548,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
logger.Debug("SNI extraction failed: %v", err)
return
}
metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds())
metrics.RecordProxyTLSHandshake(time.Since(clientHelloStart).Seconds())
if hostname == "" {
log.Println("No SNI hostname found")
@@ -596,8 +596,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer targetConn.Close()
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
metrics.RecordActiveProxyConnection(hostname, 1)
defer metrics.RecordActiveProxyConnection(hostname, -1)
metrics.RecordActiveProxyConnection(1)
defer metrics.RecordActiveProxyConnection(-1)
// Send PROXY protocol header if enabled
if p.proxyProtocol {
@@ -655,7 +655,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname)
metrics.RecordProxyRouteLookup("local_override", hostname)
metrics.RecordProxyRouteLookup("local_override")
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
@@ -668,7 +668,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
_, isLocal := p.localSNIs[hostname]
p.localSNIsLock.RUnlock()
if isLocal {
metrics.RecordProxyRouteLookup("local", hostname)
metrics.RecordProxyRouteLookup("local")
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
@@ -679,16 +679,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check cache first
if cached, found := p.cache.Get(hostname); found {
if cached == nil {
metrics.RecordProxyRouteLookup("cached_not_found", hostname)
metrics.RecordProxyRouteLookup("cached_not_found")
return nil, nil // Cached negative result
}
logger.Debug("Cache hit for hostname: %s", hostname)
metrics.RecordProxyRouteLookup("cache_hit", hostname)
metrics.RecordProxyRouteLookup("cache_hit")
return cached.(*RouteRecord), nil
}
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
metrics.RecordProxyRouteLookup("cache_miss", hostname)
metrics.RecordProxyRouteLookup("cache_miss")
// Query API with timeout
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
@@ -822,7 +822,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client
}()
bytesCopied, err := io.CopyBuffer(targetConn, clientReader, *bufPtr)
metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied)
metrics.RecordProxyBytesTransmitted("client_to_target", bytesCopied)
if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err)
}
@@ -842,7 +842,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client
}()
bytesCopied, err := io.CopyBuffer(clientConn, targetConn, *bufPtr)
metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied)
metrics.RecordProxyBytesTransmitted("target_to_client", bytesCopied)
if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err)
}