From 4357ddf64b9b5c32409548b547c5a7af26e36429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 15:57:53 +0200 Subject: [PATCH] Integrate metrics instrumentation across core services --- main.go | 220 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 207 insertions(+), 13 deletions(-) diff --git a/main.go b/main.go index 695de55..62bfe7c 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "syscall" "time" + "github.com/fosrl/gerbil/internal/metrics" "github.com/fosrl/gerbil/logger" "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" @@ -101,6 +102,35 @@ type UpdateDestinationsRequest struct { Destinations []relay.PeerDestination `json:"destinations"` } +// httpMetricsMiddleware wraps HTTP handlers with metrics tracking +func httpMetricsMiddleware(endpoint string, handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + // Create a response writer wrapper to capture status code + ww := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK} + + // Call the actual handler + handler(ww, r) + + // Record metrics + duration := time.Since(startTime).Seconds() + metrics.RecordHTTPRequest(endpoint, r.Method, fmt.Sprintf("%d", ww.statusCode)) + metrics.RecordHTTPRequestDuration(endpoint, r.Method, duration) + } +} + +// responseWriterWrapper wraps http.ResponseWriter to capture status code +type responseWriterWrapper struct { + http.ResponseWriter + statusCode int +} + +func (w *responseWriterWrapper) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -136,6 +166,15 @@ func main() { localOverridesStr string trustedUpstreamsStr string proxyProtocol bool + + // Metrics configuration variables (set from env, then overridden by CLI flags) + metricsEnabled bool + metricsBackend string + metricsPath string + otelMetricsProtocol string + otelMetricsEndpoint string + otelMetricsInsecure bool + otelMetricsExportInterval time.Duration ) interfaceName = os.Getenv("INTERFACE") @@ -157,6 +196,40 @@ func main() { doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING") bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT") + // Read metrics env vars (defaults applied by DefaultMetricsConfig; these override defaults). + metricsEnabled = true // default + if v := os.Getenv("METRICS_ENABLED"); v != "" { + metricsEnabled = strings.ToLower(v) == "true" + } + metricsBackend = "prometheus" // default + if v := os.Getenv("METRICS_BACKEND"); v != "" { + metricsBackend = v + } + metricsPath = "/metrics" // default + if v := os.Getenv("METRICS_PATH"); v != "" { + metricsPath = v + } + otelMetricsProtocol = "grpc" // default + if v := os.Getenv("OTEL_METRICS_PROTOCOL"); v != "" { + otelMetricsProtocol = v + } + otelMetricsEndpoint = "localhost:4317" // default + if v := os.Getenv("OTEL_METRICS_ENDPOINT"); v != "" { + otelMetricsEndpoint = v + } + otelMetricsInsecure = true // default + if v := os.Getenv("OTEL_METRICS_INSECURE"); v != "" { + otelMetricsInsecure = strings.ToLower(v) == "true" + } + otelMetricsExportInterval = 60 * time.Second // default + if v := os.Getenv("OTEL_METRICS_EXPORT_INTERVAL"); v != "" { + if d, err2 := time.ParseDuration(v); err2 == nil { + otelMetricsExportInterval = d + } else { + log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2) + } + } + if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } @@ -241,6 +314,15 @@ func main() { flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)") } + // Metrics CLI flags – always registered so that CLI overrides env/defaults. + flag.BoolVar(&metricsEnabled, "metrics-enabled", metricsEnabled, "Enable metrics collection (default: true)") + flag.StringVar(&metricsBackend, "metrics-backend", metricsBackend, "Metrics backend: prometheus, otel, or none") + flag.StringVar(&metricsPath, "metrics-path", metricsPath, "HTTP path for Prometheus /metrics endpoint") + flag.StringVar(&otelMetricsProtocol, "otel-metrics-protocol", otelMetricsProtocol, "OTLP transport protocol: grpc or http") + flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)") + flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection") + flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes") + flag.Parse() // Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars) @@ -252,6 +334,38 @@ func main() { logger.Init() logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + // Initialize metrics with the selected backend. + // Config precedence: CLI flags > env vars > defaults (already applied above). + metricsHandler, err := metrics.Initialize(metrics.Config{ + Enabled: metricsEnabled, + Backend: metricsBackend, + Prometheus: metrics.PrometheusConfig{ + Path: metricsPath, + }, + OTel: metrics.OTelConfig{ + Protocol: otelMetricsProtocol, + Endpoint: otelMetricsEndpoint, + Insecure: otelMetricsInsecure, + ExportInterval: otelMetricsExportInterval, + }, + ServiceName: "gerbil", + ServiceVersion: "1.0.0", + DeploymentEnvironment: os.Getenv("DEPLOYMENT_ENVIRONMENT"), + }) + if err != nil { + logger.Fatal("Failed to initialize metrics: %v", err) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metrics.Shutdown(shutdownCtx); err != nil { + logger.Error("Failed to shutdown metrics: %v", err) + } + }() + + // Record restart metric + metrics.RecordRestart() + // Base context for the application; cancel on SIGINT/SIGTERM ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -420,18 +534,27 @@ func main() { logger.Fatal("Failed to start proxy: %v", err) } - // Set up HTTP server - http.HandleFunc("/peer", handlePeer) - http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) - http.HandleFunc("/update-destinations", handleUpdateDestinations) - http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) - http.HandleFunc("/healthz", handleHealthz) + // Set up HTTP server with metrics middleware + http.HandleFunc("/peer", httpMetricsMiddleware("peer", handlePeer)) + http.HandleFunc("/update-proxy-mapping", httpMetricsMiddleware("update_proxy_mapping", handleUpdateProxyMapping)) + http.HandleFunc("/update-destinations", httpMetricsMiddleware("update_destinations", handleUpdateDestinations)) + http.HandleFunc("/update-local-snis", httpMetricsMiddleware("update_local_snis", handleUpdateLocalSNIs)) + http.HandleFunc("/healthz", httpMetricsMiddleware("healthz", handleHealthz)) + + // Register metrics endpoint only for Prometheus backend. + // OTel backend pushes to a collector; no /metrics endpoint needed. + if metricsHandler != nil { + http.Handle(metricsPath, metricsHandler) + logger.Info("Metrics endpoint enabled at %s", metricsPath) + } + logger.Info("Starting HTTP server on %s", listenAddr) // HTTP server with graceful shutdown on context cancel server := &http.Server{ - Addr: listenAddr, - Handler: nil, + Addr: listenAddr, + Handler: nil, + ReadHeaderTimeout: 3 * time.Second, } group.Go(func() error { // http.ErrServerClosed is returned on graceful shutdown; not an error for us @@ -466,26 +589,35 @@ func main() { func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { var body *bytes.Buffer if reachableAt == "" { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q}`, key.PublicKey().String()))) } else { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q, "reachableAt": %q}`, key.PublicKey().String(), reachableAt))) } resp, err := http.Post(url, "application/json", body) if err != nil { // print the error logger.Error("Error fetching remote config %s: %v", url, err) + // Record remote config fetch error + metrics.RecordRemoteConfigFetch("error") return WgConfig{}, err } defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { + metrics.RecordRemoteConfigFetch("error") return WgConfig{}, err } var config WgConfig err = json.Unmarshal(data, &config) + if err != nil { + metrics.RecordRemoteConfigFetch("error") + return config, err + } + // Record successful remote config fetch + metrics.RecordRemoteConfigFetch("success") return config, err } @@ -593,6 +725,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error { logger.Info("WireGuard interface %s created and configured", interfaceName) + // Record interface state metric + hostname, _ := os.Hostname() + metrics.RecordInterfaceUp(interfaceName, hostname, true) + return nil } @@ -890,15 +1026,22 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) { var peer Peer if err := json.NewDecoder(r.Body).Decode(&peer); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) + // Record peer add error + metrics.RecordPeerOperation("add", "error") return } err := addPeer(peer) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + // Record peer add error + metrics.RecordPeerOperation("add", "error") return } + // Record peer add success + metrics.RecordPeerOperation("add", "success") + // Notify if notifyURL is set go notifyPeerChange("add", peer.PublicKey) @@ -971,6 +1114,10 @@ func addPeerInternal(peer Peer) error { logger.Info("Peer %s added successfully", peer.PublicKey) + // Record metrics + metrics.RecordPeersTotal(interfaceName, 1) + metrics.RecordAllowedIPsCount(interfaceName, peer.PublicKey, int64(len(peer.AllowedIPs))) + return nil } @@ -978,15 +1125,22 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) { publicKey := r.URL.Query().Get("public_key") if publicKey == "" { http.Error(w, "Missing public_key query parameter", http.StatusBadRequest) + // Record peer remove error + metrics.RecordPeerOperation("remove", "error") return } err := removePeer(publicKey) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + // Record peer remove error + metrics.RecordPeerOperation("remove", "error") return } + // Record peer remove success + metrics.RecordPeerOperation("remove", "success") + // Notify if notifyURL is set go notifyPeerChange("remove", publicKey) @@ -1052,6 +1206,10 @@ func removePeerInternal(publicKey string) error { logger.Info("Peer %s removed successfully", publicKey) + // Record metrics + metrics.RecordPeersTotal(interfaceName, -1) + metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs))) + return nil } @@ -1086,6 +1244,8 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + // Record error + metrics.RecordProxyMappingUpdateRequest("error") return } @@ -1096,6 +1256,9 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { update.OldDestination.DestinationIP, update.OldDestination.DestinationPort, update.NewDestination.DestinationIP, update.NewDestination.DestinationPort) + // Record success + metrics.RecordProxyMappingUpdateRequest("success") + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "status": "Proxy mappings updated successfully", @@ -1156,6 +1319,8 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + // Record error + metrics.RecordDestinationsUpdateRequest("error") return } @@ -1164,6 +1329,9 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { logger.Info("Updated proxy mapping for %s:%d with %d destinations", request.SourceIP, request.SourcePort, len(request.Destinations)) + // Record success + metrics.RecordDestinationsUpdateRequest("success") + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "status": "Destinations updated successfully", @@ -1225,7 +1393,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { return nil, fmt.Errorf("failed to get device: %v", err) } - peerBandwidths := []PeerBandwidth{} + var peerBandwidths []PeerBandwidth now := time.Now() mu.Lock() @@ -1266,6 +1434,14 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { bytesInMB := bytesInDiff / (1024 * 1024) bytesOutMB := bytesOutDiff / (1024 * 1024) + // Record metrics (in bytes) + if bytesInDiff > 0 { + metrics.RecordBytesReceived(interfaceName, publicKey, int64(bytesInDiff)) + } + if bytesOutDiff > 0 { + metrics.RecordBytesTransmitted(interfaceName, publicKey, int64(bytesOutDiff)) + } + peerBandwidths = append(peerBandwidths, PeerBandwidth{ PublicKey: publicKey, BytesIn: bytesInMB, @@ -1305,24 +1481,31 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { func reportPeerBandwidth(apiURL string) error { bandwidths, err := calculatePeerBandwidth() if err != nil { + // Record bandwidth report error + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } jsonData, err := json.Marshal(bandwidths) if err != nil { + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to marshal bandwidth data: %v", err) } resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) if err != nil { + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to send bandwidth data: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + metrics.RecordBandwidthReport("error") return fmt.Errorf("API returned non-OK status: %s", resp.Status) } + // Record successful bandwidth report + metrics.RecordBandwidthReport("success") return nil } @@ -1356,14 +1539,25 @@ func monitorMemory(limit uint64) { for { runtime.ReadMemStats(&m) if m.Alloc > limit { + // Determine severity based on how much over the limit + severity := "warning" + if m.Alloc > limit*2 { + severity = "critical" + } + fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc) + // Record memory spike metric + metrics.RecordMemorySpike(severity) + f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix())) if err != nil { log.Println("could not create profile:", err) } else { pprof.WriteHeapProfile(f) f.Close() + // Record heap profile written metric + metrics.RecordHeapProfileWritten() } // Wait a while before checking again to avoid spamming profiles @@ -1522,7 +1716,7 @@ func setupPeerBandwidthLimit(peerIP string) error { if output, err := cmd.CombinedOutput(); err != nil { logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output)) } - + // Set up ingress shaping on the IFB device (peer's upload / ingress on wg0). // All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer // class + src filter here so each peer gets its own independent rate limit. @@ -1626,7 +1820,7 @@ func removePeerBandwidthLimit(peerIP string) error { logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output)) } } - + // Remove the ingress class and filters on the IFB device ifbClassID := fmt.Sprintf("2:%s", lastOctet)