diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 4a92b9f..e38aa90 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -7,7 +7,9 @@ package metrics import ( "context" + "fmt" "net/http" + "sync" "github.com/fosrl/gerbil/internal/observability" ) @@ -24,6 +26,7 @@ type OTelConfig = observability.OTelConfig var ( backend observability.Backend + initMu sync.Mutex // Interface and peer metrics wgInterfaceUp observability.Int64Gauge @@ -55,7 +58,6 @@ var ( udpPacketSizeBytes observability.Histogram holePunchEventsTotal observability.Counter proxyMappingActive observability.UpDownCounter - sessionActive observability.UpDownCounter sessionRebuiltTotal observability.Counter commPatternActive observability.UpDownCounter proxyCleanupRemovedTotal observability.Counter @@ -107,6 +109,13 @@ func DefaultConfig() Config { // Initialize sets up the metrics system using the selected backend. // It returns the /metrics HTTP handler (non-nil only for Prometheus backend). func Initialize(cfg Config) (http.Handler, error) { + initMu.Lock() + defer initMu.Unlock() + + if backend != nil { + return backend.HTTPHandler(), nil + } + b, err := observability.New(cfg) if err != nil { return nil, err @@ -114,6 +123,7 @@ func Initialize(cfg Config) (http.Handler, error) { backend = b if err := createInstruments(); err != nil { + backend = nil return nil, err } @@ -122,8 +132,13 @@ func Initialize(cfg Config) (http.Handler, error) { // Shutdown gracefully shuts down the metrics backend. func Shutdown(ctx context.Context) error { - if backend != nil { - return backend.Shutdown(ctx) + initMu.Lock() + b := backend + backend = nil + initMu.Unlock() + + if b != nil { + return b.Shutdown(ctx) } return nil } @@ -135,129 +150,346 @@ func createInstruments() error { b := backend - wgInterfaceUp = b.NewInt64Gauge("gerbil_wg_interface_up", + newCounter := func(name, desc string, labelNames ...string) (observability.Counter, error) { + c, err := b.NewCounter(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create counter %q: %w", name, err) + } + return c, nil + } + + newUpDownCounter := func(name, desc string, labelNames ...string) (observability.UpDownCounter, error) { + c, err := b.NewUpDownCounter(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create updown counter %q: %w", name, err) + } + return c, nil + } + + newInt64Gauge := func(name, desc string, labelNames ...string) (observability.Int64Gauge, error) { + g, err := b.NewInt64Gauge(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create int64 gauge %q: %w", name, err) + } + return g, nil + } + + newFloat64Gauge := func(name, desc string, labelNames ...string) (observability.Float64Gauge, error) { + g, err := b.NewFloat64Gauge(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create float64 gauge %q: %w", name, err) + } + return g, nil + } + + newHistogram := func(name, desc string, buckets []float64, labelNames ...string) (observability.Histogram, error) { + h, err := b.NewHistogram(name, desc, buckets, labelNames...) + if err != nil { + return nil, fmt.Errorf("create histogram %q: %w", name, err) + } + return h, nil + } + + var err error + + wgInterfaceUp, err = newInt64Gauge("gerbil_wg_interface_up", "Operational state of a WireGuard interface (1=up, 0=down)", "ifname", "instance") - wgPeersTotal = b.NewUpDownCounter("gerbil_wg_peers_total", + if err != nil { + return err + } + wgPeersTotal, err = newUpDownCounter("gerbil_wg_peers_total", "Total number of configured peers per interface", "ifname") - wgPeerConnected = b.NewInt64Gauge("gerbil_wg_peer_connected", + if err != nil { + return err + } + wgPeerConnected, err = newInt64Gauge("gerbil_wg_peer_connected", "Whether a specific peer is connected (1=connected, 0=disconnected)", "ifname", "peer") - allowedIPsCount = b.NewUpDownCounter("gerbil_allowed_ips_count", + if err != nil { + return err + } + allowedIPsCount, err = newUpDownCounter("gerbil_allowed_ips_count", "Number of allowed IPs configured per peer", "ifname", "peer") - keyRotationTotal = b.NewCounter("gerbil_key_rotation_total", + if err != nil { + return err + } + keyRotationTotal, err = newCounter("gerbil_key_rotation_total", "Key rotation events", "ifname", "reason") - wgHandshakesTotal = b.NewCounter("gerbil_wg_handshakes_total", + if err != nil { + return err + } + wgHandshakesTotal, err = newCounter("gerbil_wg_handshakes_total", "Count of handshake attempts with their result status", "ifname", "peer", "result") - wgHandshakeLatency = b.NewHistogram("gerbil_wg_handshake_latency_seconds", + if err != nil { + return err + } + wgHandshakeLatency, err = newHistogram("gerbil_wg_handshake_latency_seconds", "Distribution of handshake latencies in seconds", durationBuckets, "ifname", "peer") - wgPeerRTT = b.NewHistogram("gerbil_wg_peer_rtt_seconds", + if err != nil { + return err + } + wgPeerRTT, err = newHistogram("gerbil_wg_peer_rtt_seconds", "Observed round-trip time to a peer in seconds", durationBuckets, "ifname", "peer") - wgBytesReceived = b.NewCounter("gerbil_wg_bytes_received_total", + if err != nil { + return err + } + wgBytesReceived, err = newCounter("gerbil_wg_bytes_received_total", "Number of bytes received from a peer", "ifname", "peer") - wgBytesTransmitted = b.NewCounter("gerbil_wg_bytes_transmitted_total", + if err != nil { + return err + } + wgBytesTransmitted, err = newCounter("gerbil_wg_bytes_transmitted_total", "Number of bytes transmitted to a peer", "ifname", "peer") - netlinkEventsTotal = b.NewCounter("gerbil_netlink_events_total", + if err != nil { + return err + } + netlinkEventsTotal, err = newCounter("gerbil_netlink_events_total", "Number of netlink events processed", "event_type") - netlinkErrorsTotal = b.NewCounter("gerbil_netlink_errors_total", + if err != nil { + return err + } + netlinkErrorsTotal, err = newCounter("gerbil_netlink_errors_total", "Count of netlink or kernel errors", "component", "error_type") - syncDuration = b.NewHistogram("gerbil_sync_duration_seconds", + if err != nil { + return err + } + syncDuration, err = newHistogram("gerbil_sync_duration_seconds", "Duration of reconciliation/sync loops in seconds", durationBuckets, "component") - workqueueDepth = b.NewUpDownCounter("gerbil_workqueue_depth", + if err != nil { + return err + } + workqueueDepth, err = newUpDownCounter("gerbil_workqueue_depth", "Current length of internal work queues", "queue") - kernelModuleLoads = b.NewCounter("gerbil_kernel_module_loads_total", + if err != nil { + return err + } + kernelModuleLoads, err = newCounter("gerbil_kernel_module_loads_total", "Count of kernel module load attempts", "result") - firewallRulesApplied = b.NewCounter("gerbil_firewall_rules_applied_total", + if err != nil { + return err + } + firewallRulesApplied, err = newCounter("gerbil_firewall_rules_applied_total", "IPTables/NFT rules applied", "result", "chain") - activeSessions = b.NewUpDownCounter("gerbil_active_sessions", + if err != nil { + return err + } + activeSessions, err = newUpDownCounter("gerbil_active_sessions", "Number of active UDP relay sessions", "ifname") - activeProxyConnections = b.NewUpDownCounter("gerbil_active_proxy_connections", + if err != nil { + return err + } + activeProxyConnections, err = newUpDownCounter("gerbil_active_proxy_connections", "Active SNI proxy connections") - proxyRouteLookups = b.NewCounter("gerbil_proxy_route_lookups_total", + if err != nil { + return err + } + proxyRouteLookups, err = newCounter("gerbil_proxy_route_lookups_total", "Number of route lookups", "result") - proxyTLSHandshake = b.NewHistogram("gerbil_proxy_tls_handshake_seconds", + if err != nil { + return err + } + proxyTLSHandshake, err = newHistogram("gerbil_proxy_tls_handshake_seconds", "TLS handshake duration for SNI proxy in seconds", durationBuckets) - proxyBytesTransmitted = b.NewCounter("gerbil_proxy_bytes_transmitted_total", + if err != nil { + return err + } + proxyBytesTransmitted, err = newCounter("gerbil_proxy_bytes_transmitted_total", "Bytes sent/received by the SNI proxy", "direction") - configReloadsTotal = b.NewCounter("gerbil_config_reloads_total", + if err != nil { + return err + } + configReloadsTotal, err = newCounter("gerbil_config_reloads_total", "Number of configuration reloads", "result") - restartTotal = b.NewCounter("gerbil_restart_total", + if err != nil { + return err + } + restartTotal, err = newCounter("gerbil_restart_total", "Process restart count") - authFailuresTotal = b.NewCounter("gerbil_auth_failures_total", + if err != nil { + return err + } + authFailuresTotal, err = newCounter("gerbil_auth_failures_total", "Count of authentication or peer validation failures", "peer", "reason") - aclDeniedTotal = b.NewCounter("gerbil_acl_denied_total", + if err != nil { + return err + } + aclDeniedTotal, err = newCounter("gerbil_acl_denied_total", "Access control denied events", "ifname", "peer", "policy") - certificateExpiryDays = b.NewFloat64Gauge("gerbil_certificate_expiry_days", + if err != nil { + return err + } + certificateExpiryDays, err = newFloat64Gauge("gerbil_certificate_expiry_days", "Days until certificate expiry", "cert_name", "ifname") - udpPacketsTotal = b.NewCounter("gerbil_udp_packets_total", + if err != nil { + return err + } + udpPacketsTotal, err = newCounter("gerbil_udp_packets_total", "Count of UDP packets processed by relay workers", "ifname", "type", "direction") - udpPacketSizeBytes = b.NewHistogram("gerbil_udp_packet_size_bytes", + if err != nil { + return err + } + udpPacketSizeBytes, err = newHistogram("gerbil_udp_packet_size_bytes", "Size distribution of packets forwarded through relay", sizeBuckets, "ifname", "type") - holePunchEventsTotal = b.NewCounter("gerbil_hole_punch_events_total", + if err != nil { + return err + } + holePunchEventsTotal, err = newCounter("gerbil_hole_punch_events_total", "Count of hole punch messages processed", "ifname", "result") - proxyMappingActive = b.NewUpDownCounter("gerbil_proxy_mapping_active", + if err != nil { + return err + } + proxyMappingActive, err = newUpDownCounter("gerbil_proxy_mapping_active", "Number of active proxy mappings", "ifname") - sessionActive = b.NewUpDownCounter("gerbil_session_active", - "Number of active WireGuard sessions", "ifname") - sessionRebuiltTotal = b.NewCounter("gerbil_session_rebuilt_total", + if err != nil { + return err + } + sessionRebuiltTotal, err = newCounter("gerbil_session_rebuilt_total", "Count of sessions rebuilt from communication patterns", "ifname") - commPatternActive = b.NewUpDownCounter("gerbil_comm_pattern_active", + if err != nil { + return err + } + commPatternActive, err = newUpDownCounter("gerbil_comm_pattern_active", "Number of active communication patterns", "ifname") - proxyCleanupRemovedTotal = b.NewCounter("gerbil_proxy_cleanup_removed_total", + if err != nil { + return err + } + proxyCleanupRemovedTotal, err = newCounter("gerbil_proxy_cleanup_removed_total", "Count of items removed during cleanup routines", "ifname", "component") - proxyConnectionErrorsTotal = b.NewCounter("gerbil_proxy_connection_errors_total", + if err != nil { + return err + } + proxyConnectionErrorsTotal, err = newCounter("gerbil_proxy_connection_errors_total", "Count of connection errors in proxy operations", "ifname", "error_type") - proxyInitialMappingsTotal = b.NewInt64Gauge("gerbil_proxy_initial_mappings", + if err != nil { + return err + } + proxyInitialMappingsTotal, err = newInt64Gauge("gerbil_proxy_initial_mappings", "Number of initial proxy mappings loaded", "ifname") - proxyMappingUpdatesTotal = b.NewCounter("gerbil_proxy_mapping_updates_total", + if err != nil { + return err + } + proxyMappingUpdatesTotal, err = newCounter("gerbil_proxy_mapping_updates_total", "Count of proxy mapping updates", "ifname") - proxyIdleCleanupDuration = b.NewHistogram("gerbil_proxy_idle_cleanup_duration_seconds", + if err != nil { + return err + } + proxyIdleCleanupDuration, err = newHistogram("gerbil_proxy_idle_cleanup_duration_seconds", "Duration of cleanup cycles", durationBuckets, "ifname", "component") - sniConnectionsTotal = b.NewCounter("gerbil_sni_connections_total", + if err != nil { + return err + } + sniConnectionsTotal, err = newCounter("gerbil_sni_connections_total", "Count of connections processed by SNI proxy", "result") - sniConnectionDuration = b.NewHistogram("gerbil_sni_connection_duration_seconds", + if err != nil { + return err + } + sniConnectionDuration, err = newHistogram("gerbil_sni_connection_duration_seconds", "Lifetime distribution of proxied TLS connections", sniDurationBuckets) - sniActiveConnections = b.NewUpDownCounter("gerbil_sni_active_connections", + if err != nil { + return err + } + sniActiveConnections, err = newUpDownCounter("gerbil_sni_active_connections", "Number of active SNI tunnels") - sniRouteCacheHitsTotal = b.NewCounter("gerbil_sni_route_cache_hits_total", + if err != nil { + return err + } + sniRouteCacheHitsTotal, err = newCounter("gerbil_sni_route_cache_hits_total", "Count of route cache hits and misses", "result") - sniRouteAPIRequestsTotal = b.NewCounter("gerbil_sni_route_api_requests_total", + if err != nil { + return err + } + sniRouteAPIRequestsTotal, err = newCounter("gerbil_sni_route_api_requests_total", "Count of route API requests", "result") - sniRouteAPILatency = b.NewHistogram("gerbil_sni_route_api_latency_seconds", + if err != nil { + return err + } + sniRouteAPILatency, err = newHistogram("gerbil_sni_route_api_latency_seconds", "Distribution of route API call latencies", durationBuckets) - sniLocalOverrideTotal = b.NewCounter("gerbil_sni_local_override_total", + if err != nil { + return err + } + sniLocalOverrideTotal, err = newCounter("gerbil_sni_local_override_total", "Count of routes using local overrides", "hit") - sniTrustedProxyEventsTotal = b.NewCounter("gerbil_sni_trusted_proxy_events_total", + if err != nil { + return err + } + sniTrustedProxyEventsTotal, err = newCounter("gerbil_sni_trusted_proxy_events_total", "Count of PROXY protocol events", "event") - sniProxyProtocolParseErrorsTotal = b.NewCounter("gerbil_sni_proxy_protocol_parse_errors_total", + if err != nil { + return err + } + sniProxyProtocolParseErrorsTotal, err = newCounter("gerbil_sni_proxy_protocol_parse_errors_total", "Count of PROXY protocol parse failures") - sniDataBytesTotal = b.NewCounter("gerbil_sni_data_bytes_total", + if err != nil { + return err + } + sniDataBytesTotal, err = newCounter("gerbil_sni_data_bytes_total", "Count of bytes proxied through SNI tunnels", "direction") - sniTunnelTerminationsTotal = b.NewCounter("gerbil_sni_tunnel_terminations_total", + if err != nil { + return err + } + sniTunnelTerminationsTotal, err = newCounter("gerbil_sni_tunnel_terminations_total", "Count of tunnel terminations by reason", "reason") - httpRequestsTotal = b.NewCounter("gerbil_http_requests_total", + if err != nil { + return err + } + httpRequestsTotal, err = newCounter("gerbil_http_requests_total", "Count of HTTP requests to management API", "endpoint", "method", "status_code") - httpRequestDuration = b.NewHistogram("gerbil_http_request_duration_seconds", + if err != nil { + return err + } + httpRequestDuration, err = newHistogram("gerbil_http_request_duration_seconds", "Distribution of HTTP request handling time", durationBuckets, "endpoint", "method") - peerOperationsTotal = b.NewCounter("gerbil_peer_operations_total", + if err != nil { + return err + } + peerOperationsTotal, err = newCounter("gerbil_peer_operations_total", "Count of peer lifecycle operations", "operation", "result") - proxyMappingUpdateRequestsTotal = b.NewCounter("gerbil_proxy_mapping_update_requests_total", + if err != nil { + return err + } + proxyMappingUpdateRequestsTotal, err = newCounter("gerbil_proxy_mapping_update_requests_total", "Count of proxy mapping update API calls", "result") - destinationsUpdateRequestsTotal = b.NewCounter("gerbil_destinations_update_requests_total", + if err != nil { + return err + } + destinationsUpdateRequestsTotal, err = newCounter("gerbil_destinations_update_requests_total", "Count of destinations update API calls", "result") - remoteConfigFetchesTotal = b.NewCounter("gerbil_remote_config_fetches_total", + if err != nil { + return err + } + remoteConfigFetchesTotal, err = newCounter("gerbil_remote_config_fetches_total", "Count of remote configuration fetch attempts", "result") - bandwidthReportsTotal = b.NewCounter("gerbil_bandwidth_reports_total", + if err != nil { + return err + } + bandwidthReportsTotal, err = newCounter("gerbil_bandwidth_reports_total", "Count of bandwidth report transmissions", "result") - peerBandwidthBytesTotal = b.NewCounter("gerbil_peer_bandwidth_bytes_total", + if err != nil { + return err + } + peerBandwidthBytesTotal, err = newCounter("gerbil_peer_bandwidth_bytes_total", "Bytes per peer tracked by bandwidth calculation", "peer", "direction") - memorySpikeTotal = b.NewCounter("gerbil_memory_spike_total", + if err != nil { + return err + } + memorySpikeTotal, err = newCounter("gerbil_memory_spike_total", "Count of memory spikes detected", "severity") - heapProfilesWrittenTotal = b.NewCounter("gerbil_heap_profiles_written_total", + if err != nil { + return err + } + heapProfilesWrittenTotal, err = newCounter("gerbil_heap_profiles_written_total", "Count of heap profile files generated") + if err != nil { + return err + } return nil } func RecordInterfaceUp(ifname, instance string, up bool) { + if wgInterfaceUp == nil { + return + } value := int64(0) if up { value = 1 @@ -266,10 +498,16 @@ func RecordInterfaceUp(ifname, instance string, up bool) { } func RecordPeersTotal(ifname string, delta int64) { + if wgPeersTotal == nil { + return + } wgPeersTotal.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordPeerConnected(ifname, peer string, connected bool) { + if wgPeerConnected == nil { + return + } value := int64(0) if connected { value = 1 @@ -278,229 +516,393 @@ func RecordPeerConnected(ifname, peer string, connected bool) { } func RecordHandshake(ifname, peer, result string) { + if wgHandshakesTotal == nil { + return + } wgHandshakesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "result": result}) } func RecordHandshakeLatency(ifname, peer string, seconds float64) { + if wgHandshakeLatency == nil { + return + } wgHandshakeLatency.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordPeerRTT(ifname, peer string, seconds float64) { + if wgPeerRTT == nil { + return + } wgPeerRTT.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordBytesReceived(ifname, peer string, bytes int64) { + if wgBytesReceived == nil { + return + } wgBytesReceived.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordBytesTransmitted(ifname, peer string, bytes int64) { + if wgBytesTransmitted == nil { + return + } wgBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordAllowedIPsCount(ifname, peer string, delta int64) { + if allowedIPsCount == nil { + return + } allowedIPsCount.Add(context.Background(), delta, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordKeyRotation(ifname, reason string) { + if keyRotationTotal == nil { + return + } keyRotationTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "reason": reason}) } func RecordNetlinkEvent(eventType string) { + if netlinkEventsTotal == nil { + return + } netlinkEventsTotal.Add(context.Background(), 1, observability.Labels{"event_type": eventType}) } func RecordNetlinkError(component, errorType string) { + if netlinkErrorsTotal == nil { + return + } netlinkErrorsTotal.Add(context.Background(), 1, observability.Labels{"component": component, "error_type": errorType}) } func RecordSyncDuration(component string, seconds float64) { + if syncDuration == nil { + return + } syncDuration.Record(context.Background(), seconds, observability.Labels{"component": component}) } func RecordWorkqueueDepth(queue string, delta int64) { + if workqueueDepth == nil { + return + } workqueueDepth.Add(context.Background(), delta, observability.Labels{"queue": queue}) } func RecordKernelModuleLoad(result string) { + if kernelModuleLoads == nil { + return + } kernelModuleLoads.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordFirewallRuleApplied(result, chain string) { + if firewallRulesApplied == nil { + return + } firewallRulesApplied.Add(context.Background(), 1, observability.Labels{"result": result, "chain": chain}) } func RecordActiveSession(ifname string, delta int64) { + if activeSessions == nil { + return + } activeSessions.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } -func RecordActiveProxyConnection(hostname string, delta int64) { - _ = hostname +func RecordActiveProxyConnection(delta int64) { + if activeProxyConnections == nil { + return + } activeProxyConnections.Add(context.Background(), delta, nil) } -func RecordProxyRouteLookup(result, hostname string) { - _ = hostname +func RecordProxyRouteLookup(result string) { + if proxyRouteLookups == nil { + return + } proxyRouteLookups.Add(context.Background(), 1, observability.Labels{"result": result}) } -func RecordProxyTLSHandshake(hostname string, seconds float64) { - _ = hostname +func RecordProxyTLSHandshake(seconds float64) { + if proxyTLSHandshake == nil { + return + } proxyTLSHandshake.Record(context.Background(), seconds, nil) } -func RecordProxyBytesTransmitted(hostname, direction string, bytes int64) { - _ = hostname +func RecordProxyBytesTransmitted(direction string, bytes int64) { + if proxyBytesTransmitted == nil { + return + } proxyBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"direction": direction}) } func RecordConfigReload(result string) { + if configReloadsTotal == nil { + return + } configReloadsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordRestart() { + if restartTotal == nil { + return + } restartTotal.Add(context.Background(), 1, nil) } func RecordAuthFailure(peer, reason string) { + if authFailuresTotal == nil { + return + } authFailuresTotal.Add(context.Background(), 1, observability.Labels{"peer": peer, "reason": reason}) } func RecordACLDenied(ifname, peer, policy string) { + if aclDeniedTotal == nil { + return + } aclDeniedTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "policy": policy}) } func RecordCertificateExpiry(certName, ifname string, days float64) { + if certificateExpiryDays == nil { + return + } certificateExpiryDays.Record(context.Background(), days, observability.Labels{"cert_name": certName, "ifname": ifname}) } func RecordUDPPacket(ifname, packetType, direction string) { + if udpPacketsTotal == nil { + return + } udpPacketsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "type": packetType, "direction": direction}) } func RecordUDPPacketSize(ifname, packetType string, bytes float64) { + if udpPacketSizeBytes == nil { + return + } udpPacketSizeBytes.Record(context.Background(), bytes, observability.Labels{"ifname": ifname, "type": packetType}) } func RecordHolePunchEvent(ifname, result string) { + if holePunchEventsTotal == nil { + return + } holePunchEventsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "result": result}) } func RecordProxyMapping(ifname string, delta int64) { + if proxyMappingActive == nil { + return + } proxyMappingActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordSession(ifname string, delta int64) { - sessionActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) + if activeSessions == nil { + return + } + activeSessions.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordSessionRebuilt(ifname string) { + if sessionRebuiltTotal == nil { + return + } sessionRebuiltTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) } func RecordCommPattern(ifname string, delta int64) { + if commPatternActive == nil { + return + } commPatternActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordProxyCleanupRemoved(ifname, component string, count int64) { + if proxyCleanupRemovedTotal == nil { + return + } proxyCleanupRemovedTotal.Add(context.Background(), count, observability.Labels{"ifname": ifname, "component": component}) } func RecordProxyConnectionError(ifname, errorType string) { + if proxyConnectionErrorsTotal == nil { + return + } proxyConnectionErrorsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "error_type": errorType}) } func RecordProxyInitialMappings(ifname string, count int64) { + if proxyInitialMappingsTotal == nil { + return + } proxyInitialMappingsTotal.Record(context.Background(), count, observability.Labels{"ifname": ifname}) } func RecordProxyMappingUpdate(ifname string) { + if proxyMappingUpdatesTotal == nil { + return + } proxyMappingUpdatesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) } func RecordProxyIdleCleanupDuration(ifname, component string, seconds float64) { + if proxyIdleCleanupDuration == nil { + return + } proxyIdleCleanupDuration.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "component": component}) } func RecordSNIConnection(result string) { + if sniConnectionsTotal == nil { + return + } sniConnectionsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIConnectionDuration(seconds float64) { + if sniConnectionDuration == nil { + return + } sniConnectionDuration.Record(context.Background(), seconds, nil) } func RecordSNIActiveConnection(delta int64) { + if sniActiveConnections == nil { + return + } sniActiveConnections.Add(context.Background(), delta, nil) } func RecordSNIRouteCacheHit(result string) { + if sniRouteCacheHitsTotal == nil { + return + } sniRouteCacheHitsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIRouteAPIRequest(result string) { + if sniRouteAPIRequestsTotal == nil { + return + } sniRouteAPIRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIRouteAPILatency(seconds float64) { + if sniRouteAPILatency == nil { + return + } sniRouteAPILatency.Record(context.Background(), seconds, nil) } func RecordSNILocalOverride(hit string) { + if sniLocalOverrideTotal == nil { + return + } sniLocalOverrideTotal.Add(context.Background(), 1, observability.Labels{"hit": hit}) } func RecordSNITrustedProxyEvent(event string) { + if sniTrustedProxyEventsTotal == nil { + return + } sniTrustedProxyEventsTotal.Add(context.Background(), 1, observability.Labels{"event": event}) } func RecordSNIProxyProtocolParseError() { + if sniProxyProtocolParseErrorsTotal == nil { + return + } sniProxyProtocolParseErrorsTotal.Add(context.Background(), 1, nil) } func RecordSNIDataBytes(direction string, bytes int64) { + if sniDataBytesTotal == nil { + return + } sniDataBytesTotal.Add(context.Background(), bytes, observability.Labels{"direction": direction}) } func RecordSNITunnelTermination(reason string) { + if sniTunnelTerminationsTotal == nil { + return + } sniTunnelTerminationsTotal.Add(context.Background(), 1, observability.Labels{"reason": reason}) } func RecordHTTPRequest(endpoint, method, statusCode string) { + if httpRequestsTotal == nil { + return + } httpRequestsTotal.Add(context.Background(), 1, observability.Labels{"endpoint": endpoint, "method": method, "status_code": statusCode}) } func RecordHTTPRequestDuration(endpoint, method string, seconds float64) { + if httpRequestDuration == nil { + return + } httpRequestDuration.Record(context.Background(), seconds, observability.Labels{"endpoint": endpoint, "method": method}) } func RecordPeerOperation(operation, result string) { + if peerOperationsTotal == nil { + return + } peerOperationsTotal.Add(context.Background(), 1, observability.Labels{"operation": operation, "result": result}) } func RecordProxyMappingUpdateRequest(result string) { + if proxyMappingUpdateRequestsTotal == nil { + return + } proxyMappingUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordDestinationsUpdateRequest(result string) { + if destinationsUpdateRequestsTotal == nil { + return + } destinationsUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordRemoteConfigFetch(result string) { + if remoteConfigFetchesTotal == nil { + return + } remoteConfigFetchesTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordBandwidthReport(result string) { + if bandwidthReportsTotal == nil { + return + } bandwidthReportsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordPeerBandwidthBytes(peer, direction string, bytes int64) { + if peerBandwidthBytesTotal == nil { + return + } peerBandwidthBytesTotal.Add(context.Background(), bytes, observability.Labels{"peer": peer, "direction": direction}) } func RecordMemorySpike(severity string) { + if memorySpikeTotal == nil { + return + } memorySpikeTotal.Add(context.Background(), 1, observability.Labels{"severity": severity}) } func RecordHeapProfileWritten() { + if heapProfilesWrittenTotal == nil { + return + } heapProfilesWrittenTotal.Add(context.Background(), 1, nil) } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 132c3fe..8c01c68 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -89,6 +89,9 @@ func TestDefaultConfig(t *testing.T) { } func TestShutdownNoInit(t *testing.T) { + // Ensure a known clean global state before testing no-init shutdown behavior. + _ = metrics.Shutdown(context.Background()) + // Shutdown without Initialize should not panic or error. if err := metrics.Shutdown(context.Background()); err != nil { t.Errorf("unexpected error: %v", err) @@ -168,6 +171,7 @@ func TestRecordRelay(t *testing.T) { body := scrape(t, h) assertContains(t, body, "gerbil_udp_packets_total") assertContains(t, body, "gerbil_proxy_mapping_active") + assertContains(t, body, "gerbil_active_sessions") } func TestRecordWireGuard(t *testing.T) { @@ -216,10 +220,10 @@ func TestRecordNetlink(t *testing.T) { metrics.RecordKernelModuleLoad("success") metrics.RecordFirewallRuleApplied("success", "INPUT") metrics.RecordActiveSession("wg0", 1) - metrics.RecordActiveProxyConnection(exampleHostname, 1) - metrics.RecordProxyRouteLookup("hit", exampleHostname) - metrics.RecordProxyTLSHandshake(exampleHostname, 0.05) - metrics.RecordProxyBytesTransmitted(exampleHostname, "tx", 1024) + metrics.RecordActiveProxyConnection(1) + metrics.RecordProxyRouteLookup("hit") + metrics.RecordProxyTLSHandshake(0.05) + metrics.RecordProxyBytesTransmitted("tx", 1024) body := scrape(t, h) assertContains(t, body, "gerbil_netlink_events_total") assertContains(t, body, "gerbil_active_sessions")