Integrate metrics instrumentation across core services

This commit is contained in:
Marc Schäfer
2026-04-03 15:57:53 +02:00
parent f322b4c921
commit 4357ddf64b

220
main.go
View File

@@ -23,6 +23,7 @@ import (
"syscall"
"time"
"github.com/fosrl/gerbil/internal/metrics"
"github.com/fosrl/gerbil/logger"
"github.com/fosrl/gerbil/proxy"
"github.com/fosrl/gerbil/relay"
@@ -101,6 +102,35 @@ type UpdateDestinationsRequest struct {
Destinations []relay.PeerDestination `json:"destinations"`
}
// httpMetricsMiddleware wraps HTTP handlers with metrics tracking
func httpMetricsMiddleware(endpoint string, handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
startTime := time.Now()
// Create a response writer wrapper to capture status code
ww := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK}
// Call the actual handler
handler(ww, r)
// Record metrics
duration := time.Since(startTime).Seconds()
metrics.RecordHTTPRequest(endpoint, r.Method, fmt.Sprintf("%d", ww.statusCode))
metrics.RecordHTTPRequestDuration(endpoint, r.Method, duration)
}
}
// responseWriterWrapper wraps http.ResponseWriter to capture status code
type responseWriterWrapper struct {
http.ResponseWriter
statusCode int
}
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
w.statusCode = statusCode
w.ResponseWriter.WriteHeader(statusCode)
}
func parseLogLevel(level string) logger.LogLevel {
switch strings.ToUpper(level) {
case "DEBUG":
@@ -136,6 +166,15 @@ func main() {
localOverridesStr string
trustedUpstreamsStr string
proxyProtocol bool
// Metrics configuration variables (set from env, then overridden by CLI flags)
metricsEnabled bool
metricsBackend string
metricsPath string
otelMetricsProtocol string
otelMetricsEndpoint string
otelMetricsInsecure bool
otelMetricsExportInterval time.Duration
)
interfaceName = os.Getenv("INTERFACE")
@@ -157,6 +196,40 @@ func main() {
doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING")
bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT")
// Read metrics env vars (defaults applied by DefaultMetricsConfig; these override defaults).
metricsEnabled = true // default
if v := os.Getenv("METRICS_ENABLED"); v != "" {
metricsEnabled = strings.ToLower(v) == "true"
}
metricsBackend = "prometheus" // default
if v := os.Getenv("METRICS_BACKEND"); v != "" {
metricsBackend = v
}
metricsPath = "/metrics" // default
if v := os.Getenv("METRICS_PATH"); v != "" {
metricsPath = v
}
otelMetricsProtocol = "grpc" // default
if v := os.Getenv("OTEL_METRICS_PROTOCOL"); v != "" {
otelMetricsProtocol = v
}
otelMetricsEndpoint = "localhost:4317" // default
if v := os.Getenv("OTEL_METRICS_ENDPOINT"); v != "" {
otelMetricsEndpoint = v
}
otelMetricsInsecure = true // default
if v := os.Getenv("OTEL_METRICS_INSECURE"); v != "" {
otelMetricsInsecure = strings.ToLower(v) == "true"
}
otelMetricsExportInterval = 60 * time.Second // default
if v := os.Getenv("OTEL_METRICS_EXPORT_INTERVAL"); v != "" {
if d, err2 := time.ParseDuration(v); err2 == nil {
otelMetricsExportInterval = d
} else {
log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2)
}
}
if interfaceName == "" {
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
}
@@ -241,6 +314,15 @@ func main() {
flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)")
}
// Metrics CLI flags always registered so that CLI overrides env/defaults.
flag.BoolVar(&metricsEnabled, "metrics-enabled", metricsEnabled, "Enable metrics collection (default: true)")
flag.StringVar(&metricsBackend, "metrics-backend", metricsBackend, "Metrics backend: prometheus, otel, or none")
flag.StringVar(&metricsPath, "metrics-path", metricsPath, "HTTP path for Prometheus /metrics endpoint")
flag.StringVar(&otelMetricsProtocol, "otel-metrics-protocol", otelMetricsProtocol, "OTLP transport protocol: grpc or http")
flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)")
flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection")
flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes")
flag.Parse()
// Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars)
@@ -252,6 +334,38 @@ func main() {
logger.Init()
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
// Initialize metrics with the selected backend.
// Config precedence: CLI flags > env vars > defaults (already applied above).
metricsHandler, err := metrics.Initialize(metrics.Config{
Enabled: metricsEnabled,
Backend: metricsBackend,
Prometheus: metrics.PrometheusConfig{
Path: metricsPath,
},
OTel: metrics.OTelConfig{
Protocol: otelMetricsProtocol,
Endpoint: otelMetricsEndpoint,
Insecure: otelMetricsInsecure,
ExportInterval: otelMetricsExportInterval,
},
ServiceName: "gerbil",
ServiceVersion: "1.0.0",
DeploymentEnvironment: os.Getenv("DEPLOYMENT_ENVIRONMENT"),
})
if err != nil {
logger.Fatal("Failed to initialize metrics: %v", err)
}
defer func() {
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := metrics.Shutdown(shutdownCtx); err != nil {
logger.Error("Failed to shutdown metrics: %v", err)
}
}()
// Record restart metric
metrics.RecordRestart()
// Base context for the application; cancel on SIGINT/SIGTERM
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
defer stop()
@@ -420,18 +534,27 @@ func main() {
logger.Fatal("Failed to start proxy: %v", err)
}
// Set up HTTP server
http.HandleFunc("/peer", handlePeer)
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
http.HandleFunc("/update-destinations", handleUpdateDestinations)
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
http.HandleFunc("/healthz", handleHealthz)
// Set up HTTP server with metrics middleware
http.HandleFunc("/peer", httpMetricsMiddleware("peer", handlePeer))
http.HandleFunc("/update-proxy-mapping", httpMetricsMiddleware("update_proxy_mapping", handleUpdateProxyMapping))
http.HandleFunc("/update-destinations", httpMetricsMiddleware("update_destinations", handleUpdateDestinations))
http.HandleFunc("/update-local-snis", httpMetricsMiddleware("update_local_snis", handleUpdateLocalSNIs))
http.HandleFunc("/healthz", httpMetricsMiddleware("healthz", handleHealthz))
// Register metrics endpoint only for Prometheus backend.
// OTel backend pushes to a collector; no /metrics endpoint needed.
if metricsHandler != nil {
http.Handle(metricsPath, metricsHandler)
logger.Info("Metrics endpoint enabled at %s", metricsPath)
}
logger.Info("Starting HTTP server on %s", listenAddr)
// HTTP server with graceful shutdown on context cancel
server := &http.Server{
Addr: listenAddr,
Handler: nil,
Addr: listenAddr,
Handler: nil,
ReadHeaderTimeout: 3 * time.Second,
}
group.Go(func() error {
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
@@ -466,26 +589,35 @@ func main() {
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
var body *bytes.Buffer
if reachableAt == "" {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q}`, key.PublicKey().String())))
} else {
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q, "reachableAt": %q}`, key.PublicKey().String(), reachableAt)))
}
resp, err := http.Post(url, "application/json", body)
if err != nil {
// print the error
logger.Error("Error fetching remote config %s: %v", url, err)
// Record remote config fetch error
metrics.RecordRemoteConfigFetch("error")
return WgConfig{}, err
}
defer resp.Body.Close()
data, err := io.ReadAll(resp.Body)
if err != nil {
metrics.RecordRemoteConfigFetch("error")
return WgConfig{}, err
}
var config WgConfig
err = json.Unmarshal(data, &config)
if err != nil {
metrics.RecordRemoteConfigFetch("error")
return config, err
}
// Record successful remote config fetch
metrics.RecordRemoteConfigFetch("success")
return config, err
}
@@ -593,6 +725,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
logger.Info("WireGuard interface %s created and configured", interfaceName)
// Record interface state metric
hostname, _ := os.Hostname()
metrics.RecordInterfaceUp(interfaceName, hostname, true)
return nil
}
@@ -890,15 +1026,22 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
var peer Peer
if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
// Record peer add error
metrics.RecordPeerOperation("add", "error")
return
}
err := addPeer(peer)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Record peer add error
metrics.RecordPeerOperation("add", "error")
return
}
// Record peer add success
metrics.RecordPeerOperation("add", "success")
// Notify if notifyURL is set
go notifyPeerChange("add", peer.PublicKey)
@@ -971,6 +1114,10 @@ func addPeerInternal(peer Peer) error {
logger.Info("Peer %s added successfully", peer.PublicKey)
// Record metrics
metrics.RecordPeersTotal(interfaceName, 1)
metrics.RecordAllowedIPsCount(interfaceName, peer.PublicKey, int64(len(peer.AllowedIPs)))
return nil
}
@@ -978,15 +1125,22 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
publicKey := r.URL.Query().Get("public_key")
if publicKey == "" {
http.Error(w, "Missing public_key query parameter", http.StatusBadRequest)
// Record peer remove error
metrics.RecordPeerOperation("remove", "error")
return
}
err := removePeer(publicKey)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
// Record peer remove error
metrics.RecordPeerOperation("remove", "error")
return
}
// Record peer remove success
metrics.RecordPeerOperation("remove", "success")
// Notify if notifyURL is set
go notifyPeerChange("remove", publicKey)
@@ -1052,6 +1206,10 @@ func removePeerInternal(publicKey string) error {
logger.Info("Peer %s removed successfully", publicKey)
// Record metrics
metrics.RecordPeersTotal(interfaceName, -1)
metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs)))
return nil
}
@@ -1086,6 +1244,8 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
// Record error
metrics.RecordProxyMappingUpdateRequest("error")
return
}
@@ -1096,6 +1256,9 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
// Record success
metrics.RecordProxyMappingUpdateRequest("success")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Proxy mappings updated successfully",
@@ -1156,6 +1319,8 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
if proxyRelay == nil {
logger.Error("Proxy server is not available")
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
// Record error
metrics.RecordDestinationsUpdateRequest("error")
return
}
@@ -1164,6 +1329,9 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
request.SourceIP, request.SourcePort, len(request.Destinations))
// Record success
metrics.RecordDestinationsUpdateRequest("success")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(map[string]interface{}{
"status": "Destinations updated successfully",
@@ -1225,7 +1393,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
return nil, fmt.Errorf("failed to get device: %v", err)
}
peerBandwidths := []PeerBandwidth{}
var peerBandwidths []PeerBandwidth
now := time.Now()
mu.Lock()
@@ -1266,6 +1434,14 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
bytesInMB := bytesInDiff / (1024 * 1024)
bytesOutMB := bytesOutDiff / (1024 * 1024)
// Record metrics (in bytes)
if bytesInDiff > 0 {
metrics.RecordBytesReceived(interfaceName, publicKey, int64(bytesInDiff))
}
if bytesOutDiff > 0 {
metrics.RecordBytesTransmitted(interfaceName, publicKey, int64(bytesOutDiff))
}
peerBandwidths = append(peerBandwidths, PeerBandwidth{
PublicKey: publicKey,
BytesIn: bytesInMB,
@@ -1305,24 +1481,31 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
func reportPeerBandwidth(apiURL string) error {
bandwidths, err := calculatePeerBandwidth()
if err != nil {
// Record bandwidth report error
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
}
jsonData, err := json.Marshal(bandwidths)
if err != nil {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
}
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
if err != nil {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("failed to send bandwidth data: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
metrics.RecordBandwidthReport("error")
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
}
// Record successful bandwidth report
metrics.RecordBandwidthReport("success")
return nil
}
@@ -1356,14 +1539,25 @@ func monitorMemory(limit uint64) {
for {
runtime.ReadMemStats(&m)
if m.Alloc > limit {
// Determine severity based on how much over the limit
severity := "warning"
if m.Alloc > limit*2 {
severity = "critical"
}
fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc)
// Record memory spike metric
metrics.RecordMemorySpike(severity)
f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix()))
if err != nil {
log.Println("could not create profile:", err)
} else {
pprof.WriteHeapProfile(f)
f.Close()
// Record heap profile written metric
metrics.RecordHeapProfileWritten()
}
// Wait a while before checking again to avoid spamming profiles
@@ -1522,7 +1716,7 @@ func setupPeerBandwidthLimit(peerIP string) error {
if output, err := cmd.CombinedOutput(); err != nil {
logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output))
}
// Set up ingress shaping on the IFB device (peer's upload / ingress on wg0).
// All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer
// class + src filter here so each peer gets its own independent rate limit.
@@ -1626,7 +1820,7 @@ func removePeerBandwidthLimit(peerIP string) error {
logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output))
}
}
// Remove the ingress class and filters on the IFB device
ifbClassID := fmt.Sprintf("2:%s", lastOctet)