mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-14 04:10:03 +00:00
Integrate metrics instrumentation across core services
This commit is contained in:
220
main.go
220
main.go
@@ -23,6 +23,7 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/internal/metrics"
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/fosrl/gerbil/proxy"
|
||||
"github.com/fosrl/gerbil/relay"
|
||||
@@ -101,6 +102,35 @@ type UpdateDestinationsRequest struct {
|
||||
Destinations []relay.PeerDestination `json:"destinations"`
|
||||
}
|
||||
|
||||
// httpMetricsMiddleware wraps HTTP handlers with metrics tracking
|
||||
func httpMetricsMiddleware(endpoint string, handler http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
|
||||
// Create a response writer wrapper to capture status code
|
||||
ww := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK}
|
||||
|
||||
// Call the actual handler
|
||||
handler(ww, r)
|
||||
|
||||
// Record metrics
|
||||
duration := time.Since(startTime).Seconds()
|
||||
metrics.RecordHTTPRequest(endpoint, r.Method, fmt.Sprintf("%d", ww.statusCode))
|
||||
metrics.RecordHTTPRequestDuration(endpoint, r.Method, duration)
|
||||
}
|
||||
}
|
||||
|
||||
// responseWriterWrapper wraps http.ResponseWriter to capture status code
|
||||
type responseWriterWrapper struct {
|
||||
http.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.statusCode = statusCode
|
||||
w.ResponseWriter.WriteHeader(statusCode)
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) logger.LogLevel {
|
||||
switch strings.ToUpper(level) {
|
||||
case "DEBUG":
|
||||
@@ -136,6 +166,15 @@ func main() {
|
||||
localOverridesStr string
|
||||
trustedUpstreamsStr string
|
||||
proxyProtocol bool
|
||||
|
||||
// Metrics configuration variables (set from env, then overridden by CLI flags)
|
||||
metricsEnabled bool
|
||||
metricsBackend string
|
||||
metricsPath string
|
||||
otelMetricsProtocol string
|
||||
otelMetricsEndpoint string
|
||||
otelMetricsInsecure bool
|
||||
otelMetricsExportInterval time.Duration
|
||||
)
|
||||
|
||||
interfaceName = os.Getenv("INTERFACE")
|
||||
@@ -157,6 +196,40 @@ func main() {
|
||||
doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING")
|
||||
bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT")
|
||||
|
||||
// Read metrics env vars (defaults applied by DefaultMetricsConfig; these override defaults).
|
||||
metricsEnabled = true // default
|
||||
if v := os.Getenv("METRICS_ENABLED"); v != "" {
|
||||
metricsEnabled = strings.ToLower(v) == "true"
|
||||
}
|
||||
metricsBackend = "prometheus" // default
|
||||
if v := os.Getenv("METRICS_BACKEND"); v != "" {
|
||||
metricsBackend = v
|
||||
}
|
||||
metricsPath = "/metrics" // default
|
||||
if v := os.Getenv("METRICS_PATH"); v != "" {
|
||||
metricsPath = v
|
||||
}
|
||||
otelMetricsProtocol = "grpc" // default
|
||||
if v := os.Getenv("OTEL_METRICS_PROTOCOL"); v != "" {
|
||||
otelMetricsProtocol = v
|
||||
}
|
||||
otelMetricsEndpoint = "localhost:4317" // default
|
||||
if v := os.Getenv("OTEL_METRICS_ENDPOINT"); v != "" {
|
||||
otelMetricsEndpoint = v
|
||||
}
|
||||
otelMetricsInsecure = true // default
|
||||
if v := os.Getenv("OTEL_METRICS_INSECURE"); v != "" {
|
||||
otelMetricsInsecure = strings.ToLower(v) == "true"
|
||||
}
|
||||
otelMetricsExportInterval = 60 * time.Second // default
|
||||
if v := os.Getenv("OTEL_METRICS_EXPORT_INTERVAL"); v != "" {
|
||||
if d, err2 := time.ParseDuration(v); err2 == nil {
|
||||
otelMetricsExportInterval = d
|
||||
} else {
|
||||
log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2)
|
||||
}
|
||||
}
|
||||
|
||||
if interfaceName == "" {
|
||||
flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface")
|
||||
}
|
||||
@@ -241,6 +314,15 @@ func main() {
|
||||
flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)")
|
||||
}
|
||||
|
||||
// Metrics CLI flags – always registered so that CLI overrides env/defaults.
|
||||
flag.BoolVar(&metricsEnabled, "metrics-enabled", metricsEnabled, "Enable metrics collection (default: true)")
|
||||
flag.StringVar(&metricsBackend, "metrics-backend", metricsBackend, "Metrics backend: prometheus, otel, or none")
|
||||
flag.StringVar(&metricsPath, "metrics-path", metricsPath, "HTTP path for Prometheus /metrics endpoint")
|
||||
flag.StringVar(&otelMetricsProtocol, "otel-metrics-protocol", otelMetricsProtocol, "OTLP transport protocol: grpc or http")
|
||||
flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)")
|
||||
flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection")
|
||||
flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
// Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars)
|
||||
@@ -252,6 +334,38 @@ func main() {
|
||||
logger.Init()
|
||||
logger.GetLogger().SetLevel(parseLogLevel(logLevel))
|
||||
|
||||
// Initialize metrics with the selected backend.
|
||||
// Config precedence: CLI flags > env vars > defaults (already applied above).
|
||||
metricsHandler, err := metrics.Initialize(metrics.Config{
|
||||
Enabled: metricsEnabled,
|
||||
Backend: metricsBackend,
|
||||
Prometheus: metrics.PrometheusConfig{
|
||||
Path: metricsPath,
|
||||
},
|
||||
OTel: metrics.OTelConfig{
|
||||
Protocol: otelMetricsProtocol,
|
||||
Endpoint: otelMetricsEndpoint,
|
||||
Insecure: otelMetricsInsecure,
|
||||
ExportInterval: otelMetricsExportInterval,
|
||||
},
|
||||
ServiceName: "gerbil",
|
||||
ServiceVersion: "1.0.0",
|
||||
DeploymentEnvironment: os.Getenv("DEPLOYMENT_ENVIRONMENT"),
|
||||
})
|
||||
if err != nil {
|
||||
logger.Fatal("Failed to initialize metrics: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := metrics.Shutdown(shutdownCtx); err != nil {
|
||||
logger.Error("Failed to shutdown metrics: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Record restart metric
|
||||
metrics.RecordRestart()
|
||||
|
||||
// Base context for the application; cancel on SIGINT/SIGTERM
|
||||
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
|
||||
defer stop()
|
||||
@@ -420,18 +534,27 @@ func main() {
|
||||
logger.Fatal("Failed to start proxy: %v", err)
|
||||
}
|
||||
|
||||
// Set up HTTP server
|
||||
http.HandleFunc("/peer", handlePeer)
|
||||
http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping)
|
||||
http.HandleFunc("/update-destinations", handleUpdateDestinations)
|
||||
http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs)
|
||||
http.HandleFunc("/healthz", handleHealthz)
|
||||
// Set up HTTP server with metrics middleware
|
||||
http.HandleFunc("/peer", httpMetricsMiddleware("peer", handlePeer))
|
||||
http.HandleFunc("/update-proxy-mapping", httpMetricsMiddleware("update_proxy_mapping", handleUpdateProxyMapping))
|
||||
http.HandleFunc("/update-destinations", httpMetricsMiddleware("update_destinations", handleUpdateDestinations))
|
||||
http.HandleFunc("/update-local-snis", httpMetricsMiddleware("update_local_snis", handleUpdateLocalSNIs))
|
||||
http.HandleFunc("/healthz", httpMetricsMiddleware("healthz", handleHealthz))
|
||||
|
||||
// Register metrics endpoint only for Prometheus backend.
|
||||
// OTel backend pushes to a collector; no /metrics endpoint needed.
|
||||
if metricsHandler != nil {
|
||||
http.Handle(metricsPath, metricsHandler)
|
||||
logger.Info("Metrics endpoint enabled at %s", metricsPath)
|
||||
}
|
||||
|
||||
logger.Info("Starting HTTP server on %s", listenAddr)
|
||||
|
||||
// HTTP server with graceful shutdown on context cancel
|
||||
server := &http.Server{
|
||||
Addr: listenAddr,
|
||||
Handler: nil,
|
||||
Addr: listenAddr,
|
||||
Handler: nil,
|
||||
ReadHeaderTimeout: 3 * time.Second,
|
||||
}
|
||||
group.Go(func() error {
|
||||
// http.ErrServerClosed is returned on graceful shutdown; not an error for us
|
||||
@@ -466,26 +589,35 @@ func main() {
|
||||
func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) {
|
||||
var body *bytes.Buffer
|
||||
if reachableAt == "" {
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String())))
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q}`, key.PublicKey().String())))
|
||||
} else {
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt)))
|
||||
body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q, "reachableAt": %q}`, key.PublicKey().String(), reachableAt)))
|
||||
}
|
||||
resp, err := http.Post(url, "application/json", body)
|
||||
if err != nil {
|
||||
// print the error
|
||||
logger.Error("Error fetching remote config %s: %v", url, err)
|
||||
// Record remote config fetch error
|
||||
metrics.RecordRemoteConfigFetch("error")
|
||||
return WgConfig{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
metrics.RecordRemoteConfigFetch("error")
|
||||
return WgConfig{}, err
|
||||
}
|
||||
|
||||
var config WgConfig
|
||||
err = json.Unmarshal(data, &config)
|
||||
if err != nil {
|
||||
metrics.RecordRemoteConfigFetch("error")
|
||||
return config, err
|
||||
}
|
||||
|
||||
// Record successful remote config fetch
|
||||
metrics.RecordRemoteConfigFetch("success")
|
||||
return config, err
|
||||
}
|
||||
|
||||
@@ -593,6 +725,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error {
|
||||
|
||||
logger.Info("WireGuard interface %s created and configured", interfaceName)
|
||||
|
||||
// Record interface state metric
|
||||
hostname, _ := os.Hostname()
|
||||
metrics.RecordInterfaceUp(interfaceName, hostname, true)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -890,15 +1026,22 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) {
|
||||
var peer Peer
|
||||
if err := json.NewDecoder(r.Body).Decode(&peer); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
// Record peer add error
|
||||
metrics.RecordPeerOperation("add", "error")
|
||||
return
|
||||
}
|
||||
|
||||
err := addPeer(peer)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// Record peer add error
|
||||
metrics.RecordPeerOperation("add", "error")
|
||||
return
|
||||
}
|
||||
|
||||
// Record peer add success
|
||||
metrics.RecordPeerOperation("add", "success")
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("add", peer.PublicKey)
|
||||
|
||||
@@ -971,6 +1114,10 @@ func addPeerInternal(peer Peer) error {
|
||||
|
||||
logger.Info("Peer %s added successfully", peer.PublicKey)
|
||||
|
||||
// Record metrics
|
||||
metrics.RecordPeersTotal(interfaceName, 1)
|
||||
metrics.RecordAllowedIPsCount(interfaceName, peer.PublicKey, int64(len(peer.AllowedIPs)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -978,15 +1125,22 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) {
|
||||
publicKey := r.URL.Query().Get("public_key")
|
||||
if publicKey == "" {
|
||||
http.Error(w, "Missing public_key query parameter", http.StatusBadRequest)
|
||||
// Record peer remove error
|
||||
metrics.RecordPeerOperation("remove", "error")
|
||||
return
|
||||
}
|
||||
|
||||
err := removePeer(publicKey)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
// Record peer remove error
|
||||
metrics.RecordPeerOperation("remove", "error")
|
||||
return
|
||||
}
|
||||
|
||||
// Record peer remove success
|
||||
metrics.RecordPeerOperation("remove", "success")
|
||||
|
||||
// Notify if notifyURL is set
|
||||
go notifyPeerChange("remove", publicKey)
|
||||
|
||||
@@ -1052,6 +1206,10 @@ func removePeerInternal(publicKey string) error {
|
||||
|
||||
logger.Info("Peer %s removed successfully", publicKey)
|
||||
|
||||
// Record metrics
|
||||
metrics.RecordPeersTotal(interfaceName, -1)
|
||||
metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs)))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1086,6 +1244,8 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
// Record error
|
||||
metrics.RecordProxyMappingUpdateRequest("error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1096,6 +1256,9 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) {
|
||||
update.OldDestination.DestinationIP, update.OldDestination.DestinationPort,
|
||||
update.NewDestination.DestinationIP, update.NewDestination.DestinationPort)
|
||||
|
||||
// Record success
|
||||
metrics.RecordProxyMappingUpdateRequest("success")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Proxy mappings updated successfully",
|
||||
@@ -1156,6 +1319,8 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
|
||||
if proxyRelay == nil {
|
||||
logger.Error("Proxy server is not available")
|
||||
http.Error(w, "Proxy server is not available", http.StatusInternalServerError)
|
||||
// Record error
|
||||
metrics.RecordDestinationsUpdateRequest("error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1164,6 +1329,9 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Info("Updated proxy mapping for %s:%d with %d destinations",
|
||||
request.SourceIP, request.SourcePort, len(request.Destinations))
|
||||
|
||||
// Record success
|
||||
metrics.RecordDestinationsUpdateRequest("success")
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(map[string]interface{}{
|
||||
"status": "Destinations updated successfully",
|
||||
@@ -1225,7 +1393,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
return nil, fmt.Errorf("failed to get device: %v", err)
|
||||
}
|
||||
|
||||
peerBandwidths := []PeerBandwidth{}
|
||||
var peerBandwidths []PeerBandwidth
|
||||
now := time.Now()
|
||||
|
||||
mu.Lock()
|
||||
@@ -1266,6 +1434,14 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
bytesInMB := bytesInDiff / (1024 * 1024)
|
||||
bytesOutMB := bytesOutDiff / (1024 * 1024)
|
||||
|
||||
// Record metrics (in bytes)
|
||||
if bytesInDiff > 0 {
|
||||
metrics.RecordBytesReceived(interfaceName, publicKey, int64(bytesInDiff))
|
||||
}
|
||||
if bytesOutDiff > 0 {
|
||||
metrics.RecordBytesTransmitted(interfaceName, publicKey, int64(bytesOutDiff))
|
||||
}
|
||||
|
||||
peerBandwidths = append(peerBandwidths, PeerBandwidth{
|
||||
PublicKey: publicKey,
|
||||
BytesIn: bytesInMB,
|
||||
@@ -1305,24 +1481,31 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) {
|
||||
func reportPeerBandwidth(apiURL string) error {
|
||||
bandwidths, err := calculatePeerBandwidth()
|
||||
if err != nil {
|
||||
// Record bandwidth report error
|
||||
metrics.RecordBandwidthReport("error")
|
||||
return fmt.Errorf("failed to calculate peer bandwidth: %v", err)
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(bandwidths)
|
||||
if err != nil {
|
||||
metrics.RecordBandwidthReport("error")
|
||||
return fmt.Errorf("failed to marshal bandwidth data: %v", err)
|
||||
}
|
||||
|
||||
resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
metrics.RecordBandwidthReport("error")
|
||||
return fmt.Errorf("failed to send bandwidth data: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
metrics.RecordBandwidthReport("error")
|
||||
return fmt.Errorf("API returned non-OK status: %s", resp.Status)
|
||||
}
|
||||
|
||||
// Record successful bandwidth report
|
||||
metrics.RecordBandwidthReport("success")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1356,14 +1539,25 @@ func monitorMemory(limit uint64) {
|
||||
for {
|
||||
runtime.ReadMemStats(&m)
|
||||
if m.Alloc > limit {
|
||||
// Determine severity based on how much over the limit
|
||||
severity := "warning"
|
||||
if m.Alloc > limit*2 {
|
||||
severity = "critical"
|
||||
}
|
||||
|
||||
fmt.Printf("Memory spike detected (%d bytes). Dumping profile...\n", m.Alloc)
|
||||
|
||||
// Record memory spike metric
|
||||
metrics.RecordMemorySpike(severity)
|
||||
|
||||
f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix()))
|
||||
if err != nil {
|
||||
log.Println("could not create profile:", err)
|
||||
} else {
|
||||
pprof.WriteHeapProfile(f)
|
||||
f.Close()
|
||||
// Record heap profile written metric
|
||||
metrics.RecordHeapProfileWritten()
|
||||
}
|
||||
|
||||
// Wait a while before checking again to avoid spamming profiles
|
||||
@@ -1522,7 +1716,7 @@ func setupPeerBandwidthLimit(peerIP string) error {
|
||||
if output, err := cmd.CombinedOutput(); err != nil {
|
||||
logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
}
|
||||
|
||||
|
||||
// Set up ingress shaping on the IFB device (peer's upload / ingress on wg0).
|
||||
// All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer
|
||||
// class + src filter here so each peer gets its own independent rate limit.
|
||||
@@ -1626,7 +1820,7 @@ func removePeerBandwidthLimit(peerIP string) error {
|
||||
logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// Remove the ingress class and filters on the IFB device
|
||||
ifbClassID := fmt.Sprintf("2:%s", lastOctet)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user