From b9261b8fea620d5254ecffee996d0ef88d4c9b1b Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 27 Feb 2026 15:45:17 -0800 Subject: [PATCH 01/25] Add optional tc --- main.go | 228 ++++++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 207 insertions(+), 21 deletions(-) diff --git a/main.go b/main.go index b723530..dde0026 100644 --- a/main.go +++ b/main.go @@ -33,15 +33,16 @@ import ( ) var ( - interfaceName string - listenAddr string - mtuInt int - lastReadings = make(map[string]PeerReading) - mu sync.Mutex - wgMu sync.Mutex // Protects WireGuard operations - notifyURL string - proxyRelay *relay.UDPProxyServer - proxySNI *proxy.SNIProxy + interfaceName string + listenAddr string + mtuInt int + lastReadings = make(map[string]PeerReading) + mu sync.Mutex + wgMu sync.Mutex // Protects WireGuard operations + notifyURL string + proxyRelay *relay.UDPProxyServer + proxySNI *proxy.SNIProxy + doTrafficShaping bool ) type WgConfig struct { @@ -151,6 +152,7 @@ func main() { localOverridesStr = os.Getenv("LOCAL_OVERRIDES") trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS") proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") + doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -222,6 +224,13 @@ func main() { flag.BoolVar(&proxyProtocol, "proxy-protocol", true, "Enable PROXY protocol v1 for preserving client IP") } + if doTrafficShapingStr != "" { + doTrafficShaping = strings.ToLower(doTrafficShapingStr) == "true" + } + if doTrafficShapingStr == "" { + flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)") + } + flag.Parse() logger.Init() @@ -886,17 +895,23 @@ func addPeerInternal(peer Peer) error { return fmt.Errorf("failed to parse public key: %v", err) } + logger.Debug("Adding peer %s with AllowedIPs: %v", peer.PublicKey, peer.AllowedIPs) + // 
parse allowed IPs into array of net.IPNet var allowedIPs []net.IPNet var wgIPs []string for _, ipStr := range peer.AllowedIPs { + logger.Debug("Parsing AllowedIP: %s", ipStr) _, ipNet, err := net.ParseCIDR(ipStr) if err != nil { + logger.Warn("Failed to parse allowed IP '%s' for peer %s: %v", ipStr, peer.PublicKey, err) return fmt.Errorf("failed to parse allowed IP: %v", err) } allowedIPs = append(allowedIPs, *ipNet) // Extract the IP address from the CIDR for relay cleanup - wgIPs = append(wgIPs, ipNet.IP.String()) + extractedIP := ipNet.IP.String() + wgIPs = append(wgIPs, extractedIP) + logger.Debug("Extracted IP %s from AllowedIP %s", extractedIP, ipStr) } peerConfig := wgtypes.PeerConfig{ @@ -912,6 +927,18 @@ func addPeerInternal(peer Peer) error { return fmt.Errorf("failed to add peer: %v", err) } + // Setup bandwidth limiting for each peer IP + if doTrafficShaping { + logger.Debug("doTrafficShaping is true, setting up bandwidth limits for %d IPs", len(wgIPs)) + for _, wgIP := range wgIPs { + if err := setupPeerBandwidthLimit(wgIP); err != nil { + logger.Warn("Failed to setup bandwidth limit for peer IP %s: %v", wgIP, err) + } + } + } else { + logger.Debug("doTrafficShaping is false, skipping bandwidth limit setup") + } + // Clear relay connections for the peer's WireGuard IPs if proxyRelay != nil { for _, wgIP := range wgIPs { @@ -956,19 +983,17 @@ func removePeerInternal(publicKey string) error { return fmt.Errorf("failed to parse public key: %v", err) } - // Get current peer info before removing to clear relay connections + // Get current peer info before removing to clear relay connections and bandwidth limits var wgIPs []string - if proxyRelay != nil { - device, err := wgClient.Device(interfaceName) - if err == nil { - for _, peer := range device.Peers { - if peer.PublicKey.String() == publicKey { - // Extract WireGuard IPs from this peer's allowed IPs - for _, allowedIP := range peer.AllowedIPs { - wgIPs = append(wgIPs, allowedIP.IP.String()) - } - break 
+ device, err := wgClient.Device(interfaceName) + if err == nil { + for _, peer := range device.Peers { + if peer.PublicKey.String() == publicKey { + // Extract WireGuard IPs from this peer's allowed IPs + for _, allowedIP := range peer.AllowedIPs { + wgIPs = append(wgIPs, allowedIP.IP.String()) } + break } } } @@ -986,6 +1011,15 @@ func removePeerInternal(publicKey string) error { return fmt.Errorf("failed to remove peer: %v", err) } + // Remove bandwidth limits for each peer IP + if doTrafficShaping { + for _, wgIP := range wgIPs { + if err := removePeerBandwidthLimit(wgIP); err != nil { + logger.Warn("Failed to remove bandwidth limit for peer IP %s: %v", wgIP, err) + } + } + } + // Clear relay connections for the peer's WireGuard IPs if proxyRelay != nil { for _, wgIP := range wgIPs { @@ -1315,3 +1349,155 @@ func monitorMemory(limit uint64) { time.Sleep(5 * time.Second) } } + +// setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP +// Currently hardcoded to 20 Mbps per peer +func setupPeerBandwidthLimit(peerIP string) error { + logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP) + const bandwidthLimit = "50mbit" // 50 Mbps limit per peer + + // Parse the IP to get just the IP address (strip any CIDR notation if present) + ip := peerIP + if strings.Contains(peerIP, "/") { + parsedIP, _, err := net.ParseCIDR(peerIP) + if err != nil { + return fmt.Errorf("failed to parse peer IP: %v", err) + } + ip = parsedIP.String() + } + + // First, ensure we have a root qdisc on the interface (HTB - Hierarchical Token Bucket) + // Check if qdisc already exists + cmd := exec.Command("tc", "qdisc", "show", "dev", interfaceName) + output, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to check qdisc: %v, output: %s", err, string(output)) + } + + // If no HTB qdisc exists, create one + if !strings.Contains(string(output), "htb") { + cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, 
"root", "handle", "1:", "htb", "default", "9999") + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add root qdisc: %v, output: %s", err, string(output)) + } + logger.Info("Created HTB root qdisc on %s", interfaceName) + } + + // Generate a unique class ID based on the IP address + // We'll use the last octet of the IP as part of the class ID + ipParts := strings.Split(ip, ".") + if len(ipParts) != 4 { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + lastOctet := ipParts[3] + classID := fmt.Sprintf("1:%s", lastOctet) + logger.Debug("Generated class ID %s for peer IP %s", classID, ip) + + // Create a class for this peer with bandwidth limit + cmd = exec.Command("tc", "class", "add", "dev", interfaceName, "parent", "1:", "classid", classID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + logger.Debug("tc class add failed for %s: %v, output: %s", ip, err, string(output)) + // If class already exists, try to replace it + if strings.Contains(string(output), "File exists") { + cmd = exec.Command("tc", "class", "replace", "dev", interfaceName, "parent", "1:", "classid", classID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to replace class: %v, output: %s", err, string(output)) + } + logger.Debug("Successfully replaced existing class %s for peer IP %s", classID, ip) + } else { + return fmt.Errorf("failed to add class: %v, output: %s", err, string(output)) + } + } else { + logger.Debug("Successfully added new class %s for peer IP %s", classID, ip) + } + + // Add a filter to match traffic from this peer IP (ingress) + cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", + "prio", "1", "u32", "match", "ip", "src", ip, "flowid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // If filter fails, log but don't 
fail the peer addition + logger.Warn("Failed to add ingress filter for peer IP %s: %v, output: %s", ip, err, string(output)) + } + + // Add a filter to match traffic to this peer IP (egress) + cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", + "prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // If filter fails, log but don't fail the peer addition + logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output)) + } + + logger.Info("Setup bandwidth limit of %s for peer IP %s (class %s)", bandwidthLimit, ip, classID) + return nil +} + +// removePeerBandwidthLimit removes TC rules for a specific peer IP +func removePeerBandwidthLimit(peerIP string) error { + // Parse the IP to get just the IP address + ip := peerIP + if strings.Contains(peerIP, "/") { + parsedIP, _, err := net.ParseCIDR(peerIP) + if err != nil { + return fmt.Errorf("failed to parse peer IP: %v", err) + } + ip = parsedIP.String() + } + + // Generate the class ID based on the IP + ipParts := strings.Split(ip, ".") + if len(ipParts) != 4 { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + lastOctet := ipParts[3] + classID := fmt.Sprintf("1:%s", lastOctet) + + // Remove filters for this IP + // List all filters to find the ones for this class + cmd := exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "1:") + output, err := cmd.CombinedOutput() + if err != nil { + logger.Warn("Failed to list filters for peer IP %s: %v, output: %s", ip, err, string(output)) + } else { + // Parse the output to find filter handles that match this classID + // The output format includes lines like: + // filter parent 1: protocol ip pref 1 u32 chain 0 fh 800::800 order 2048 key ht 800 bkt 0 flowid 1:4 + lines := strings.Split(string(output), "\n") + for _, line := range lines { + // Look for lines containing our flowid (classID) + if 
strings.Contains(line, "flowid "+classID) && strings.Contains(line, "fh ") { + // Extract handle (format: fh 800::800) + parts := strings.Fields(line) + var handle string + for j, part := range parts { + if part == "fh" && j+1 < len(parts) { + handle = parts[j+1] + break + } + } + if handle != "" { + // Delete this filter using the handle + delCmd := exec.Command("tc", "filter", "del", "dev", interfaceName, "parent", "1:", "handle", handle, "prio", "1", "u32") + if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil { + logger.Debug("Failed to delete filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput)) + } else { + logger.Debug("Deleted filter handle %s for peer IP %s", handle, ip) + } + } + } + } + } + + // Remove the class + cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID) + if output, err := cmd.CombinedOutput(); err != nil { + // It's okay if the class doesn't exist + if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") { + logger.Warn("Failed to remove class for peer IP %s: %v, output: %s", ip, err, string(output)) + } + } + + logger.Info("Removed bandwidth limit for peer IP %s (class %s)", ip, classID) + return nil +} From 7985f97eb65d6f0289c32569402a70a5f6ff11b8 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 12 Mar 2026 12:54:02 +0000 Subject: [PATCH 02/25] perf(relay): scale packet workers and queue depth for throughput --- relay/relay.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index 22aff76..bc1a6a6 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "runtime" "sync" "time" @@ -164,7 +165,7 @@ func NewUDPProxyServer(parentCtx context.Context, addr, serverURL string, privat addr: addr, serverURL: serverURL, privateKey: privateKey, - packetChan: make(chan Packet, 1000), + packetChan: make(chan 
Packet, 50000), // Increased from 1000 to handle high throughput ReachableAt: reachableAt, ctx: ctx, cancel: cancel, @@ -189,8 +190,13 @@ func (s *UDPProxyServer) Start() error { s.conn = conn logger.Info("UDP server listening on %s", s.addr) - // Start a fixed number of worker goroutines. - workerCount := 10 // TODO: Make this configurable or pick it better! + // Start worker goroutines based on CPU cores for better parallelism + // At high throughput (160+ Mbps), we need many workers to avoid bottlenecks + workerCount := runtime.NumCPU() * 10 + if workerCount < 20 { + workerCount = 20 // Minimum 20 workers + } + logger.Info("Starting %d packet workers (CPUs: %d)", workerCount, runtime.NumCPU()) for i := 0; i < workerCount; i++ { go s.packetWorker() } From b118fef2654c57be8870d772d629fa85040d720f Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 12 Mar 2026 12:54:59 +0000 Subject: [PATCH 03/25] perf(relay): cache resolved UDP destinations with TTL --- relay/relay.go | 48 ++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 4 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index 22aff76..190d077 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -153,6 +153,9 @@ type UDPProxyServer struct { // Communication pattern tracking for rebuilding sessions // Key format: "clientIP:clientPort-destIP:destPort" commPatterns sync.Map + // Cache for resolved UDP addresses to avoid per-packet DNS lookups + // Key: "ip:port" string, Value: *net.UDPAddr + addrCache sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string } @@ -416,6 +419,43 @@ func extractWireGuardIndices(packet []byte) (uint32, uint32, bool) { return 0, 0, false } +// cachedAddr holds a resolved UDP address with TTL +type cachedAddr struct { + addr *net.UDPAddr + expiresAt time.Time +} + +// addrCacheTTL is how long resolved addresses are cached before re-resolving +const addrCacheTTL = 5 * time.Minute + +// getCachedAddr returns a cached UDP 
address or resolves and caches it. +// This avoids per-packet DNS lookups which are a major throughput bottleneck. +func (s *UDPProxyServer) getCachedAddr(ip string, port int) (*net.UDPAddr, error) { + key := fmt.Sprintf("%s:%d", ip, port) + + // Check cache first + if cached, ok := s.addrCache.Load(key); ok { + entry := cached.(*cachedAddr) + if time.Now().Before(entry.expiresAt) { + return entry.addr, nil + } + // Cache expired, delete and re-resolve + s.addrCache.Delete(key) + } + + // Resolve and cache + addr, err := net.ResolveUDPAddr("udp", key) + if err != nil { + return nil, err + } + + s.addrCache.Store(key, &cachedAddr{ + addr: addr, + expiresAt: time.Now().Add(addrCacheTTL), + }) + return addr, nil +} + // Updated to handle multi-peer WireGuard communication func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UDPAddr) { if len(packet) == 0 { @@ -450,7 +490,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD logger.Debug("Forwarding handshake initiation from %s (sender index: %d) to peers %v", remoteAddr, senderIndex, proxyMapping.Destinations) for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -486,7 +526,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // Forward the response to the original sender for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -543,7 +583,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet 
[]byte, remoteAddr *net.UD // No known session, fall back to forwarding to all peers logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex) for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue @@ -571,7 +611,7 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // Forward to all peers for _, dest := range proxyMapping.Destinations { - destAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", dest.DestinationIP, dest.DestinationPort)) + destAddr, err := s.getCachedAddr(dest.DestinationIP, dest.DestinationPort) if err != nil { logger.Error("Failed to resolve destination address: %v", err) continue From abc744c647e7267f9bb161ffb1bf1a410a8ea3c5 Mon Sep 17 00:00:00 2001 From: Laurence Date: Thu, 12 Mar 2026 12:55:49 +0000 Subject: [PATCH 04/25] perf(relay): index WireGuard sessions by receiver index --- relay/relay.go | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index 22aff76..402c696 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -150,6 +150,9 @@ type UDPProxyServer struct { // Session tracking for WireGuard peers // Key format: "senderIndex:receiverIndex" wgSessions sync.Map + // Session index for O(1) lookup by receiver index + // Key: receiverIndex (uint32), Value: *WireGuardSession + sessionsByReceiverIndex sync.Map // Communication pattern tracking for rebuilding sessions // Key format: "clientIP:clientPort-destIP:destPort" commPatterns sync.Map @@ -477,12 +480,15 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex) // Store 
the session information - s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: remoteAddr, LastSeen: time.Now(), - }) + } + s.wgSessions.Store(sessionKey, session) + // Also index by sender index for O(1) lookup in transport data path + s.sessionsByReceiverIndex.Store(senderIndex, session) // Forward the response to the original sender for _, dest := range proxyMapping.Destinations { @@ -508,21 +514,15 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD // Data packet: forward only to the established session peer // logger.Debug("Received transport data with receiver index %d from %s", receiverIndex, remoteAddr) - // Look up the session based on the receiver index + // Look up the session based on the receiver index - O(1) lookup instead of O(n) Range var destAddr *net.UDPAddr - // First check for existing sessions to see if we know where to send this packet - s.wgSessions.Range(func(k, v interface{}) bool { - session := v.(*WireGuardSession) - // Check if session matches (read lock for check) - if session.GetSenderIndex() == receiverIndex { - // Found matching session - get dest addr and update last seen - destAddr = session.GetDestAddr() - session.UpdateLastSeen() - return false // stop iteration - } - return true // continue iteration - }) + // Fast path: direct index lookup by receiver index + if sessionObj, ok := s.sessionsByReceiverIndex.Load(receiverIndex); ok { + session := sessionObj.(*WireGuardSession) + destAddr = session.GetDestAddr() + session.UpdateLastSeen() + } if destAddr != nil { // We found a specific peer to forward to @@ -634,12 +634,15 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse { // Store the session mapping for the handshake response sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex) - 
s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: destAddr, LastSeen: time.Now(), - }) + } + s.wgSessions.Store(sessionKey, session) + // Also index by sender index for O(1) lookup + s.sessionsByReceiverIndex.Store(senderIndex, session) logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String()) } else if ok && buffer[0] == WireGuardMessageTypeTransportData { // Track communication pattern for session rebuilding (reverse direction) From c7d9c72f2937c3d8594eaf58694d7dd8bc55b6d9 Mon Sep 17 00:00:00 2001 From: Laurence Date: Fri, 13 Mar 2026 15:28:04 +0000 Subject: [PATCH 05/25] Add HTTP client reuse and buffer pooling for performance - Add reusable HTTP client with connection pooling for API requests - Add sync.Pool for 32KB buffers used in connection piping - Clear buffers before returning to pool to prevent data leakage - Reduces GC pressure and improves throughput under load --- proxy/proxy.go | 49 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index f29878e..03579d3 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -69,6 +69,12 @@ type SNIProxy struct { // Trusted upstream proxies that can send PROXY protocol trustedUpstreams map[string]struct{} + + // Reusable HTTP client for API requests + httpClient *http.Client + + // Buffer pool for connection piping + bufferPool *sync.Pool } type activeTunnel struct { @@ -374,6 +380,20 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo localOverrides: overridesMap, activeTunnels: make(map[string]*activeTunnel), trustedUpstreams: trustedMap, + httpClient: &http.Client{ + Timeout: 5 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + }, + }, + bufferPool: &sync.Pool{ + New: func() interface{} { + buf 
:= make([]byte, 32*1024) + return &buf + }, + }, } return proxy, nil @@ -681,9 +701,8 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { } req.Header.Set("Content-Type", "application/json") - // Make HTTP request - client := &http.Client{Timeout: 5 * time.Second} - resp, err := client.Do(req) + // Make HTTP request using reusable client + resp, err := p.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("API request failed: %w", err) } @@ -773,9 +792,15 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) defer wg.Done() defer closeConns() - // Use a large buffer for better performance - buf := make([]byte, 32*1024) - _, err := io.CopyBuffer(targetConn, clientReader, buf) + // Get buffer from pool and return when done + bufPtr := p.bufferPool.Get().(*[]byte) + defer func() { + // Clear buffer before returning to pool to prevent data leakage + clear(*bufPtr) + p.bufferPool.Put(bufPtr) + }() + + _, err := io.CopyBuffer(targetConn, clientReader, *bufPtr) if err != nil && err != io.EOF { logger.Debug("Copy client->target error: %v", err) } @@ -786,9 +811,15 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) defer wg.Done() defer closeConns() - // Use a large buffer for better performance - buf := make([]byte, 32*1024) - _, err := io.CopyBuffer(clientConn, targetConn, buf) + // Get buffer from pool and return when done + bufPtr := p.bufferPool.Get().(*[]byte) + defer func() { + // Clear buffer before returning to pool to prevent data leakage + clear(*bufPtr) + p.bufferPool.Put(bufPtr) + }() + + _, err := io.CopyBuffer(clientConn, targetConn, *bufPtr) if err != nil && err != io.EOF { logger.Debug("Copy target->client error: %v", err) } From fcead8cc15c8f225a9b46138ea48ca6cdbff7b87 Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 20 Mar 2026 16:02:58 -0700 Subject: [PATCH 06/25] Add rate limit to hole punch --- relay/relay.go | 57 
++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/relay/relay.go b/relay/relay.go index 22aff76..8fabbff 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -118,6 +118,13 @@ type Packet struct { n int } +// holePunchRateLimitEntry tracks hole punch message counts within a sliding 1-second window. +type holePunchRateLimitEntry struct { + mu sync.Mutex + count int + windowStart time.Time +} + // WireGuard message types const ( WireGuardMessageTypeHandshakeInitiation = 1 @@ -153,6 +160,8 @@ type UDPProxyServer struct { // Communication pattern tracking for rebuilding sessions // Key format: "clientIP:clientPort-destIP:destPort" commPatterns sync.Map + // Rate limiter for encrypted hole punch messages, keyed by "ip:port" + holePunchRateLimiter sync.Map // ReachableAt is the URL where this server can be reached ReachableAt string } @@ -210,6 +219,9 @@ func (s *UDPProxyServer) Start() error { // Start the communication pattern cleanup routine go s.cleanupIdleCommunicationPatterns() + // Start the hole punch rate limiter cleanup routine + go s.cleanupHolePunchRateLimiter() + return nil } @@ -272,6 +284,27 @@ func (s *UDPProxyServer) packetWorker() { // Process as a WireGuard packet. 
s.handleWireGuardPacket(packet.data, packet.remoteAddr) } else { + // Rate limit: allow at most 2 hole punch messages per IP:Port per second + rateLimitKey := packet.remoteAddr.String() + entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{ + windowStart: time.Now(), + }) + rlEntry := entryVal.(*holePunchRateLimitEntry) + rlEntry.mu.Lock() + now := time.Now() + if now.Sub(rlEntry.windowStart) >= time.Second { + rlEntry.count = 0 + rlEntry.windowStart = now + } + rlEntry.count++ + allowed := rlEntry.count <= 2 + rlEntry.mu.Unlock() + if !allowed { + logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) + bufferPool.Put(packet.data[:1500]) + continue + } + // Process as an encrypted hole punch message var encMsg EncryptedHolePunchMessage if err := json.Unmarshal(packet.data, &encMsg); err != nil { @@ -1030,6 +1063,30 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { } // cleanupIdleCommunicationPatterns periodically removes idle communication patterns +// cleanupHolePunchRateLimiter periodically evicts stale rate limit entries to prevent unbounded growth. 
+func (s *UDPProxyServer) cleanupHolePunchRateLimiter() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + for { + select { + case <-ticker.C: + now := time.Now() + s.holePunchRateLimiter.Range(func(key, value interface{}) bool { + rlEntry := value.(*holePunchRateLimitEntry) + rlEntry.mu.Lock() + stale := now.Sub(rlEntry.windowStart) > 10*time.Second + rlEntry.mu.Unlock() + if stale { + s.holePunchRateLimiter.Delete(key) + } + return true + }) + case <-s.ctx.Done(): + return + } + } +} + func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { ticker := time.NewTicker(10 * time.Minute) defer ticker.Stop() From 40da38708cd71fa266d146c3757d995d32b76ecc Mon Sep 17 00:00:00 2001 From: Owen Date: Fri, 20 Mar 2026 16:11:10 -0700 Subject: [PATCH 07/25] Update logging --- relay/relay.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index af18258..0ab5930 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -309,7 +309,7 @@ func (s *UDPProxyServer) packetWorker() { allowed := rlEntry.count <= 2 rlEntry.mu.Unlock() if !allowed { - logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) + // logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) bufferPool.Put(packet.data[:1500]) continue } @@ -333,7 +333,7 @@ func (s *UDPProxyServer) packetWorker() { // This appears to be an encrypted message decryptedData, err := s.decryptMessage(encMsg) if err != nil { - logger.Error("Failed to decrypt message: %v", err) + // logger.Error("Failed to decrypt message: %v", err) // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue From a3862260c946d6924a75f3a0c9b6810bb8d0b091 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 31 Mar 2026 20:35:05 -0700 Subject: [PATCH 08/25] Add var for b limit --- main.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/main.go b/main.go index 
dde0026..cc86355 100644 --- a/main.go +++ b/main.go @@ -43,6 +43,7 @@ var ( proxyRelay *relay.UDPProxyServer proxySNI *proxy.SNIProxy doTrafficShaping bool + bandwidthLimit string ) type WgConfig struct { @@ -153,6 +154,7 @@ func main() { trustedUpstreamsStr = os.Getenv("TRUSTED_UPSTREAMS") proxyProtocolStr := os.Getenv("PROXY_PROTOCOL") doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING") + bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT") if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -231,6 +233,13 @@ func main() { flag.BoolVar(&doTrafficShaping, "do-traffic-shaping", false, "Whether to set up traffic shaping rules for peers (requires tc command and root privileges)") } + if bandwidthLimitStr != "" { + bandwidthLimit = bandwidthLimitStr + } + if bandwidthLimitStr == "" { + flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)") + } + flag.Parse() logger.Init() @@ -1351,10 +1360,10 @@ func monitorMemory(limit uint64) { } // setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP -// Currently hardcoded to 20 Mbps per peer +// Bandwidth limit is configurable via the --bandwidth-limit flag or BANDWIDTH_LIMIT env var (default: 50mbit) func setupPeerBandwidthLimit(peerIP string) error { logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP) - const bandwidthLimit = "50mbit" // 50 Mbps limit per peer + // Parse the IP to get just the IP address (strip any CIDR notation if present) ip := peerIP From b57574cc4bb098301f8856ba066f572e33fc20a9 Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 31 Mar 2026 21:56:41 -0700 Subject: [PATCH 09/25] IFB ingress limiting --- main.go | 197 +++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 181 insertions(+), 16 deletions(-) diff --git a/main.go b/main.go index cc86355..695de55 100644 --- a/main.go +++ b/main.go @@ -44,6 
+44,7 @@ var ( proxySNI *proxy.SNIProxy doTrafficShaping bool bandwidthLimit string + ifbName string // IFB device name for ingress traffic shaping ) type WgConfig struct { @@ -242,6 +243,12 @@ func main() { flag.Parse() + // Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars) + ifbName = "ifb_" + interfaceName + if len(ifbName) > 15 { + ifbName = ifbName[:15] + } + logger.Init() logger.GetLogger().SetLevel(parseLogLevel(logLevel)) @@ -353,6 +360,13 @@ func main() { logger.Fatal("Failed to ensure WireGuard interface: %v", err) } + // Set up IFB device for bidirectional ingress/egress traffic shaping if enabled + if doTrafficShaping { + if err := ensureIFBDevice(); err != nil { + logger.Fatal("Failed to ensure IFB device for traffic shaping: %v", err) + } + } + // Ensure the WireGuard peers exist ensureWireguardPeers(wgconfig.Peers) @@ -1359,12 +1373,92 @@ func monitorMemory(limit uint64) { } } +// ensureIFBDevice creates and configures the IFB (Intermediate Functional Block) device used to +// shape ingress traffic on the WireGuard interface. Linux TC qdiscs only control egress by default; +// the IFB trick redirects all ingress packets to a virtual device so HTB shaping can be applied +// there, and the packets are transparently re-injected into the kernel network stack afterwards. +// This is completely invisible to sockets/applications (including a reverse proxy on the host). 
+func ensureIFBDevice() error { + // Check if the ifb kernel module is loaded (works inside containers too) + if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) { + logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping") + return nil + } + + // Create the IFB device if it does not already exist + _, err := netlink.LinkByName(ifbName) + if err != nil { + if _, ok := err.(netlink.LinkNotFoundError); ok { + cmd := exec.Command("ip", "link", "add", ifbName, "type", "ifb") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to create IFB device %s: %v, output: %s", ifbName, err, string(out)) + } + logger.Info("Created IFB device %s", ifbName) + } else { + return fmt.Errorf("failed to look up IFB device %s: %v", ifbName, err) + } + } else { + logger.Info("IFB device %s already exists", ifbName) + } + + // Bring the IFB device up + cmd := exec.Command("ip", "link", "set", "dev", ifbName, "up") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to bring up IFB device %s: %v, output: %s", ifbName, err, string(out)) + } + + // Attach an ingress qdisc to the WireGuard interface if one is not already present + cmd = exec.Command("tc", "qdisc", "show", "dev", interfaceName) + out, err := cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to query qdiscs on %s: %v", interfaceName, err) + } + if !strings.Contains(string(out), "ingress") { + cmd = exec.Command("tc", "qdisc", "add", "dev", interfaceName, "handle", "ffff:", "ingress") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add ingress qdisc to %s: %v, output: %s", interfaceName, err, string(out)) + } + logger.Info("Added ingress qdisc to %s", interfaceName) + } + + // Add a catch-all filter that redirects every ingress packet from wg0 to the IFB device. + // Per-peer rate limiting then happens on ifb0's egress HTB qdisc (handle 2:). 
+ cmd = exec.Command("tc", "filter", "show", "dev", interfaceName, "parent", "ffff:") + out, err = cmd.CombinedOutput() + if err != nil || !strings.Contains(string(out), ifbName) { + cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, + "parent", "ffff:", "protocol", "ip", + "u32", "match", "u32", "0", "0", + "action", "mirred", "egress", "redirect", "dev", ifbName) + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add ingress redirect filter on %s: %v, output: %s", interfaceName, err, string(out)) + } + logger.Info("Added ingress redirect filter: %s -> %s", interfaceName, ifbName) + } + + // Ensure an HTB root qdisc exists on the IFB device (handle 2:) for per-peer shaping + cmd = exec.Command("tc", "qdisc", "show", "dev", ifbName) + out, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to query qdiscs on %s: %v", ifbName, err) + } + if !strings.Contains(string(out), "htb") { + cmd = exec.Command("tc", "qdisc", "add", "dev", ifbName, "root", "handle", "2:", "htb", "default", "9999") + if out, err := cmd.CombinedOutput(); err != nil { + return fmt.Errorf("failed to add HTB qdisc to %s: %v, output: %s", ifbName, err, string(out)) + } + logger.Info("Added HTB root qdisc (handle 2:) to IFB device %s", ifbName) + } + + logger.Info("IFB device %s ready for ingress traffic shaping", ifbName) + return nil +} + // setupPeerBandwidthLimit sets up TC (Traffic Control) to limit bandwidth for a specific peer IP // Bandwidth limit is configurable via the --bandwidth-limit flag or BANDWIDTH_LIMIT env var (default: 50mbit) func setupPeerBandwidthLimit(peerIP string) error { logger.Debug("setupPeerBandwidthLimit called for peer IP: %s", peerIP) - // Parse the IP to get just the IP address (strip any CIDR notation if present) ip := peerIP if strings.Contains(peerIP, "/") { @@ -1422,23 +1516,50 @@ func setupPeerBandwidthLimit(peerIP string) error { logger.Debug("Successfully added new class %s for peer IP %s", 
classID, ip) } - // Add a filter to match traffic from this peer IP (ingress) - cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", - "prio", "1", "u32", "match", "ip", "src", ip, "flowid", classID) - if output, err := cmd.CombinedOutput(); err != nil { - // If filter fails, log but don't fail the peer addition - logger.Warn("Failed to add ingress filter for peer IP %s: %v, output: %s", ip, err, string(output)) - } - - // Add a filter to match traffic to this peer IP (egress) + // Add a filter to match traffic to this peer IP on wg0 egress (peer's download) cmd = exec.Command("tc", "filter", "add", "dev", interfaceName, "protocol", "ip", "parent", "1:", "prio", "1", "u32", "match", "ip", "dst", ip, "flowid", classID) if output, err := cmd.CombinedOutput(); err != nil { - // If filter fails, log but don't fail the peer addition logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output)) } + + // Set up ingress shaping on the IFB device (peer's upload / ingress on wg0). + // All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer + // class + src filter here so each peer gets its own independent rate limit. 
+ ifbClassID := fmt.Sprintf("2:%s", lastOctet) - logger.Info("Setup bandwidth limit of %s for peer IP %s (class %s)", bandwidthLimit, ip, classID) + // Check if the ifb kernel module is loaded (works inside containers too) + if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) { + logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping.") + logger.Info("Setup bandwidth limit of %s for peer IP %s (egress class %s, ingress class %s)", bandwidthLimit, ip, classID, ifbClassID) + return nil + } + + cmd = exec.Command("tc", "class", "add", "dev", ifbName, "parent", "2:", "classid", ifbClassID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + if strings.Contains(string(output), "File exists") { + cmd = exec.Command("tc", "class", "replace", "dev", ifbName, "parent", "2:", "classid", ifbClassID, + "htb", "rate", bandwidthLimit, "ceil", bandwidthLimit) + if output, err := cmd.CombinedOutput(); err != nil { + logger.Warn("Failed to replace IFB class for peer IP %s: %v, output: %s", ip, err, string(output)) + } else { + logger.Debug("Replaced existing IFB class %s for peer IP %s", ifbClassID, ip) + } + } else { + logger.Warn("Failed to add IFB class for peer IP %s: %v, output: %s", ip, err, string(output)) + } + } else { + logger.Debug("Added IFB class %s for peer IP %s", ifbClassID, ip) + } + + cmd = exec.Command("tc", "filter", "add", "dev", ifbName, "protocol", "ip", "parent", "2:", + "prio", "1", "u32", "match", "ip", "src", ip, "flowid", ifbClassID) + if output, err := cmd.CombinedOutput(); err != nil { + logger.Warn("Failed to add IFB ingress filter for peer IP %s: %v, output: %s", ip, err, string(output)) + } + + logger.Info("Setup bandwidth limit of %s for peer IP %s (egress class %s, ingress class %s)", bandwidthLimit, ip, classID, ifbClassID) return nil } @@ -1498,15 +1619,59 @@ func removePeerBandwidthLimit(peerIP string) error { } } - // Remove the class + // Remove 
the egress class on wg0 cmd = exec.Command("tc", "class", "del", "dev", interfaceName, "classid", classID) if output, err := cmd.CombinedOutput(); err != nil { - // It's okay if the class doesn't exist if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") { - logger.Warn("Failed to remove class for peer IP %s: %v, output: %s", ip, err, string(output)) + logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output)) + } + } + + // Remove the ingress class and filters on the IFB device + ifbClassID := fmt.Sprintf("2:%s", lastOctet) + + // Check if the ifb kernel module is loaded (works inside containers too) + if _, err := os.Stat("/sys/module/ifb"); os.IsNotExist(err) { + logger.Warn("IFB module not loaded, skipping IFB setup and ingress traffic shaping") + logger.Info("Removed bandwidth limit for peer IP %s (egress class %s, ingress class %s)", ip, classID, ifbClassID) + return nil + } + + cmd = exec.Command("tc", "filter", "show", "dev", ifbName, "parent", "2:") + output, err = cmd.CombinedOutput() + if err != nil { + logger.Warn("Failed to list IFB filters for peer IP %s: %v, output: %s", ip, err, string(output)) + } else { + lines := strings.Split(string(output), "\n") + for _, line := range lines { + if strings.Contains(line, "flowid "+ifbClassID) && strings.Contains(line, "fh ") { + parts := strings.Fields(line) + var handle string + for j, part := range parts { + if part == "fh" && j+1 < len(parts) { + handle = parts[j+1] + break + } + } + if handle != "" { + delCmd := exec.Command("tc", "filter", "del", "dev", ifbName, "parent", "2:", "handle", handle, "prio", "1", "u32") + if delOutput, delErr := delCmd.CombinedOutput(); delErr != nil { + logger.Debug("Failed to delete IFB filter handle %s for peer IP %s: %v, output: %s", handle, ip, delErr, string(delOutput)) + } else { + logger.Debug("Deleted IFB filter handle %s for peer IP %s", handle, ip) + } + } + } } 
} - logger.Info("Removed bandwidth limit for peer IP %s (class %s)", ip, classID) + cmd = exec.Command("tc", "class", "del", "dev", ifbName, "classid", ifbClassID) + if output, err := cmd.CombinedOutput(); err != nil { + if !strings.Contains(string(output), "No such file or directory") && !strings.Contains(string(output), "Cannot find") { + logger.Warn("Failed to remove IFB class for peer IP %s: %v, output: %s", ip, err, string(output)) + } + } + + logger.Info("Removed bandwidth limit for peer IP %s (egress class %s, ingress class %s)", ip, classID, ifbClassID) return nil } From f322b4c92165adc031522e36533d2dffb70b1027 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 15:57:47 +0200 Subject: [PATCH 10/25] Add OpenTelemetry and Prometheus metrics infrastructure --- README.md | 18 + docs/observability.md | 269 ++++++++++ examples/otel-collector-config.yaml | 46 ++ examples/prometheus.yml | 24 + go.mod | 37 +- go.sum | 89 ++- internal/metrics/metrics.go | 506 ++++++++++++++++++ internal/metrics/metrics_test.go | 258 +++++++++ internal/observability/config.go | 119 ++++ internal/observability/metrics.go | 152 ++++++ internal/observability/metrics_test.go | 198 +++++++ internal/observability/noop.go | 71 +++ internal/observability/noop_test.go | 67 +++ internal/observability/otel/backend.go | 210 ++++++++ internal/observability/otel/backend_test.go | 141 +++++ internal/observability/otel/exporter.go | 50 ++ internal/observability/otel/resource.go | 25 + internal/observability/prometheus/backend.go | 185 +++++++ .../observability/prometheus/backend_test.go | 173 ++++++ 19 files changed, 2623 insertions(+), 15 deletions(-) create mode 100644 docs/observability.md create mode 100644 examples/otel-collector-config.yaml create mode 100644 examples/prometheus.yml create mode 100644 internal/metrics/metrics.go create mode 100644 internal/metrics/metrics_test.go create mode 100644 internal/observability/config.go create mode 100644 
internal/observability/metrics.go create mode 100644 internal/observability/metrics_test.go create mode 100644 internal/observability/noop.go create mode 100644 internal/observability/noop_test.go create mode 100644 internal/observability/otel/backend.go create mode 100644 internal/observability/otel/backend_test.go create mode 100644 internal/observability/otel/exporter.go create mode 100644 internal/observability/otel/resource.go create mode 100644 internal/observability/prometheus/backend.go create mode 100644 internal/observability/prometheus/backend_test.go diff --git a/README.md b/README.md index 85c9693..e403d88 100644 --- a/README.md +++ b/README.md @@ -40,6 +40,24 @@ The PROXY protocol allows downstream proxies to know the real client IP address In single node (self hosted) Pangolin deployments this can be bypassed by using port 443:443 to route to Traefik instead of the SNI proxy at 8443. +### Observability with OpenTelemetry + +Gerbil includes comprehensive OpenTelemetry metrics instrumentation for monitoring and observability. Metrics can be exported via: + +- **Prometheus**: Pull-based metrics at the `/metrics` endpoint (enabled by default) +- **OTLP**: Push-based metrics to any OpenTelemetry-compatible collector + +Key metrics include: + +- WireGuard interface and peer status +- Bandwidth usage per peer +- Active relay sessions and proxy connections +- Handshake success/failure rates +- Route lookup cache hit/miss ratios +- Go runtime metrics (GC, goroutines, memory) + +See [docs/observability.md](docs/observability.md) for complete documentation, metrics reference, and examples. + ## CLI Args Important: diff --git a/docs/observability.md b/docs/observability.md new file mode 100644 index 0000000..13cb038 --- /dev/null +++ b/docs/observability.md @@ -0,0 +1,269 @@ + +# Gerbil Observability Architecture + +This document describes the metrics subsystem for Gerbil, explains the design +decisions, and shows how to configure each backend. 
+ +--- + +## Architecture Overview + +Gerbil's metrics subsystem uses a **pluggable backend** design: + +```text +main.go ─── internal/metrics ─── internal/observability ─── backend + (facade) (interface) Prometheus + OR OTel/OTLP + OR Noop (disabled) +``` + +Application code (main, relay, proxy) calls only the `metrics.Record*` +functions in `internal/metrics`. That package delegates to whichever backend +was selected at startup via `internal/observability.Backend`. + +### Why Prometheus-native and OTel are mutually exclusive + +**Exactly one** metrics backend may be active at runtime: + +| Mode | What happens | +|------|-------------| +| `prometheus` | Native Prometheus client registers metrics on a dedicated registry and exposes `/metrics`. No OTel SDK is initialised. | +| `otel` | OTel SDK pushes metrics via OTLP/gRPC or OTLP/HTTP to an external collector. No `/metrics` endpoint is exposed. | +| `none` | A safe noop backend is used. All `Record*` calls are discarded. | + +Running both simultaneously would mean every metric is recorded twice through +two different code paths, with differing semantics (pull vs. push, different +naming rules, different cardinality handling). The design enforces a single +source of truth. + +### Future OTel tracing and logging + +The `internal/observability/otel/` package is designed so that tracing and +logging support can be added **beside** the existing metrics code without +touching the Prometheus-native path: + +```bash +internal/observability/otel/ + backend.go ← metrics + exporter.go ← OTLP exporter creation + resource.go ← OTel resource + trace.go ← future: TracerProvider setup + log.go ← future: LoggerProvider setup +``` + +--- + +## Configuration + +### Config precedence + +1. CLI flags (highest priority) +2. Environment variables +3. 
Defaults + +### Config struct + +```go +type MetricsConfig struct { + Enabled bool + Backend string // "prometheus" | "otel" | "none" + Prometheus PrometheusConfig + OTel OTelConfig + ServiceName string + ServiceVersion string + DeploymentEnvironment string +} + +type PrometheusConfig struct { + Path string // default: "/metrics" +} + +type OTelConfig struct { + Protocol string // "grpc" (default) or "http" + Endpoint string // default: "localhost:4317" + Insecure bool // default: true + ExportInterval time.Duration // default: 60s +} +``` + +### Environment variables + +| Variable | Default | Description | +|----------|---------|-------------| +| `METRICS_ENABLED` | `true` | Enable/disable metrics | +| `METRICS_BACKEND` | `prometheus` | Backend: `prometheus`, `otel`, or `none` | +| `METRICS_PATH` | `/metrics` | HTTP path for Prometheus endpoint | +| `OTEL_METRICS_PROTOCOL` | `grpc` | OTLP transport: `grpc` or `http` | +| `OTEL_METRICS_ENDPOINT` | `localhost:4317` | OTLP collector address | +| `OTEL_METRICS_INSECURE` | `true` | Disable TLS for OTLP | +| `OTEL_METRICS_EXPORT_INTERVAL` | `60s` | Push interval (e.g. 
`10s`, `1m`) | +| `DEPLOYMENT_ENVIRONMENT` | _(unset)_ | OTel deployment.environment attribute | + +### CLI flags + +```bash +--metrics-enabled bool (default: true) +--metrics-backend string (default: prometheus) +--metrics-path string (default: /metrics) +--otel-metrics-protocol string (default: grpc) +--otel-metrics-endpoint string (default: localhost:4317) +--otel-metrics-insecure bool (default: true) +--otel-metrics-export-interval duration (default: 1m0s) +``` + +--- + +## When to choose each backend + +| Criterion | Prometheus | OTel/OTLP | +|-----------|-----------|-----------| +| Existing Prometheus/Grafana stack | ✅ | | +| Pull-based scraping | ✅ | | +| No external collector required | ✅ | | +| Vendor-neutral telemetry | | ✅ | +| Push-based export | | ✅ | +| Grafana Cloud / managed OTLP | | ✅ | +| Future traces + logs via same pipeline | | ✅ | + +--- + +## Enabling Prometheus-native mode + +### Environment variables + +```bash +METRICS_ENABLED=true +METRICS_BACKEND=prometheus +METRICS_PATH=/metrics +``` + +### CLI + +```bash +./gerbil --metrics-enabled --metrics-backend=prometheus --metrics-path=/metrics \ + --config=/etc/gerbil/config.json +``` + +The metrics config is supplied separately via env/flags; it is not embedded +in the WireGuard config file. + +The Prometheus `/metrics` endpoint is registered only when +`--metrics-backend=prometheus`. All gerbil_* metrics plus Go runtime metrics +are available. 
+
+---
+
+## Enabling OTel mode
+
+### Environment variables
+
+```bash
+export METRICS_ENABLED=true
+export METRICS_BACKEND=otel
+export OTEL_METRICS_PROTOCOL=grpc
+export OTEL_METRICS_ENDPOINT=otel-collector:4317
+export OTEL_METRICS_INSECURE=true
+export OTEL_METRICS_EXPORT_INTERVAL=10s
+export DEPLOYMENT_ENVIRONMENT=production
+```
+
+### CLI
+
+```bash
+./gerbil --metrics-enabled \
+    --metrics-backend=otel \
+    --otel-metrics-protocol=grpc \
+    --otel-metrics-endpoint=otel-collector:4317 \
+    --otel-metrics-insecure \
+    --otel-metrics-export-interval=10s \
+    --config=/etc/gerbil/config.json
+```
+
+### HTTP mode (OTLP/HTTP)
+
+```bash
+export OTEL_METRICS_PROTOCOL=http
+export OTEL_METRICS_ENDPOINT=otel-collector:4318
+```
+
+---
+
+## Disabling metrics
+
+```bash
+export METRICS_ENABLED=false
+# or
+./gerbil --metrics-enabled=false
+# or
+./gerbil --metrics-backend=none
+```
+
+When disabled, all `Record*` calls are directed to a safe noop backend that
+discards observations without allocating or locking.
+
+---
+
+## Metric catalog
+
+All metrics use the prefix `gerbil_`.
+ +### WireGuard metrics + +| Metric | Type | Labels | Description | +|--------|------|--------|-------------| +| `gerbil_wg_interface_up` | Gauge | `ifname`, `instance` | 1=up, 0=down | +| `gerbil_wg_peers_total` | UpDownCounter | `ifname` | Configured peers | +| `gerbil_wg_peer_connected` | Gauge | `ifname`, `peer` | 1=connected, 0=disconnected | +| `gerbil_wg_bytes_received_total` | Counter | `ifname`, `peer` | Bytes received | +| `gerbil_wg_bytes_transmitted_total` | Counter | `ifname`, `peer` | Bytes transmitted | +| `gerbil_wg_handshakes_total` | Counter | `ifname`, `peer`, `result` | Handshake attempts | +| `gerbil_wg_handshake_latency_seconds` | Histogram | `ifname`, `peer` | Handshake duration | +| `gerbil_wg_peer_rtt_seconds` | Histogram | `ifname`, `peer` | Peer round-trip time | + +### Relay metrics + +| Metric | Type | Labels | +|--------|------|--------| +| `gerbil_proxy_mapping_active` | UpDownCounter | `ifname` | +| `gerbil_session_active` | UpDownCounter | `ifname` | +| `gerbil_active_sessions` | UpDownCounter | `ifname` | +| `gerbil_udp_packets_total` | Counter | `ifname`, `type`, `direction` | +| `gerbil_hole_punch_events_total` | Counter | `ifname`, `result` | + +### SNI proxy metrics + +| Metric | Type | Labels | +|--------|------|--------| +| `gerbil_sni_connections_total` | Counter | `result` | +| `gerbil_sni_active_connections` | UpDownCounter | _(none)_ | +| `gerbil_sni_route_cache_hits_total` | Counter | `result` | +| `gerbil_sni_route_api_requests_total` | Counter | `result` | +| `gerbil_proxy_route_lookups_total` | Counter | `result`, `hostname` | + +### HTTP metrics + +| Metric | Type | Labels | +|--------|------|--------| +| `gerbil_http_requests_total` | Counter | `endpoint`, `method`, `status_code` | +| `gerbil_http_request_duration_seconds` | Histogram | `endpoint`, `method` | + +--- + +## Using Docker Compose + +The `docker-compose.metrics.yml` provides a complete observability stack. 
+ +**Prometheus mode:** + +```bash +METRICS_BACKEND=prometheus docker-compose -f docker-compose.metrics.yml up -d +# Scrape at http://localhost:3003/metrics +# Grafana at http://localhost:3000 (admin/admin) +``` + +**OTel mode:** + +```bash +METRICS_BACKEND=otel OTEL_METRICS_ENDPOINT=otel-collector:4317 \ + docker-compose -f docker-compose.metrics.yml up -d +``` diff --git a/examples/otel-collector-config.yaml b/examples/otel-collector-config.yaml new file mode 100644 index 0000000..5c85356 --- /dev/null +++ b/examples/otel-collector-config.yaml @@ -0,0 +1,46 @@ +receivers: + otlp: + protocols: + grpc: + endpoint: 0.0.0.0:4317 + http: + endpoint: 0.0.0.0:4318 + +processors: + batch: + timeout: 10s + send_batch_size: 1024 + + # Add resource attributes + resource: + attributes: + - key: service.environment + value: "development" + action: insert + +exporters: + # Prometheus exporter for scraping + prometheus: + endpoint: "0.0.0.0:8889" + namespace: "gerbil" + send_timestamps: true + metric_expiration: 5m + resource_to_telemetry_conversion: + enabled: true + + # Prometheus remote write (optional) + prometheusremotewrite: + endpoint: "http://prometheus:9090/api/v1/write" + tls: + insecure: true + + # Debug exporter for debugging + debug: + verbosity: normal + +service: + pipelines: + metrics: + receivers: [otlp] + processors: [batch, resource] + exporters: [prometheus, prometheusremotewrite, debug] \ No newline at end of file diff --git a/examples/prometheus.yml b/examples/prometheus.yml new file mode 100644 index 0000000..1ca99e9 --- /dev/null +++ b/examples/prometheus.yml @@ -0,0 +1,24 @@ +global: + scrape_interval: 15s + evaluation_interval: 15s + external_labels: + cluster: 'gerbil-dev' + +scrape_configs: + # Scrape Gerbil's /metrics endpoint directly + - job_name: 'gerbil' + static_configs: + - targets: ['gerbil:3003'] + labels: + service: 'gerbil' + environment: 'development' + + # Scrape OpenTelemetry Collector metrics + - job_name: 'otel-collector' + 
static_configs: + - targets: ['otel-collector:8888'] + labels: + service: 'otel-collector' + - targets: ['otel-collector:8889'] + labels: + service: 'otel-collector-prometheus-exporter' diff --git a/go.mod b/go.mod index fb9debb..ecae7a6 100644 --- a/go.mod +++ b/go.mod @@ -4,20 +4,47 @@ go 1.25 require ( github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/prometheus/client_golang v1.20.5 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.46.0 - golang.org/x/sync v0.1.0 + go.opentelemetry.io/otel v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 + go.opentelemetry.io/otel/metric v1.42.0 + go.opentelemetry.io/otel/sdk v1.42.0 + go.opentelemetry.io/otel/sdk/metric v1.42.0 + golang.org/x/crypto v0.48.0 + golang.org/x/sync v0.19.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 + google.golang.org/grpc v1.79.3 ) require ( - github.com/google/go-cmp v0.5.9 // indirect + github.com/beorn7/perks v1.0.1 // indirect + github.com/cenkalti/backoff/v5 v5.0.3 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/go-cmp v0.7.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/josharian/native v1.1.0 // indirect + github.com/klauspost/compress v1.17.9 // indirect github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.61.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect - golang.org/x/net v0.47.0 // indirect - 
golang.org/x/sys v0.39.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/otel/trace v1.42.0 // indirect + go.opentelemetry.io/proto/otlp v1.9.0 // indirect + golang.org/x/net v0.51.0 // indirect + golang.org/x/sys v0.41.0 // indirect + golang.org/x/text v0.34.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 72c7d68..dd24281 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,30 @@ -github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= -github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod 
h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= +github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= +github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= @@ -10,23 +33,69 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/munnerz/goautoneg 
v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/prometheus/client_golang v1.20.5 h1:cxppBPuYhUnsO6yo/aoRol4L7q7UFfdm+bR9r+8l63Y= +github.com/prometheus/client_golang v1.20.5/go.mod h1:PIEt8X02hGcP8JWbeHyeZ53Y/jReSnHgO035n//V5WE= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.61.0 h1:3gv/GThfX0cV2lpO7gkTUwZru38mxevy90Bj8YFSRQQ= +github.com/prometheus/common v0.61.0/go.mod h1:zr29OCN/2BsJRaFwG8QOBr41D6kkchKbpeNH7pAjb/s= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod 
h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= -golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY= -golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= +go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 h1:H7O6RlGOMTizyl3R08Kn5pdM06bnH8oscSj7o11tmLA= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0/go.mod h1:mBFWu/WOVDkWWsR7Tx7h6EpQB8wsv7P0Yrh0Pb7othc= +go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= +go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= +go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= +go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= +go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= +go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= +go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= +go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +go.opentelemetry.io/proto/otlp v1.9.0 
h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= +go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= +golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= +golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= 
+google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go new file mode 100644 index 0000000..4a92b9f --- /dev/null +++ b/internal/metrics/metrics.go @@ -0,0 +1,506 @@ +// Package metrics provides the application-level metrics facade for Gerbil. +// +// Application code (main, relay, proxy) uses only the Record* functions in this +// package. The actual recording is delegated to the backend selected in +// internal/observability. Neither Prometheus nor OTel packages are imported here. +package metrics + +import ( + "context" + "net/http" + + "github.com/fosrl/gerbil/internal/observability" +) + +// Config is the metrics configuration type. It is an alias for +// observability.MetricsConfig so callers do not need to import observability. +type Config = observability.MetricsConfig + +// PrometheusConfig is re-exported for convenience. +type PrometheusConfig = observability.PrometheusConfig + +// OTelConfig is re-exported for convenience. 
+type OTelConfig = observability.OTelConfig + +var ( + backend observability.Backend + + // Interface and peer metrics + wgInterfaceUp observability.Int64Gauge + wgPeersTotal observability.UpDownCounter + wgPeerConnected observability.Int64Gauge + wgHandshakesTotal observability.Counter + wgHandshakeLatency observability.Histogram + wgPeerRTT observability.Histogram + wgBytesReceived observability.Counter + wgBytesTransmitted observability.Counter + allowedIPsCount observability.UpDownCounter + keyRotationTotal observability.Counter + + // System and proxy metrics + netlinkEventsTotal observability.Counter + netlinkErrorsTotal observability.Counter + syncDuration observability.Histogram + workqueueDepth observability.UpDownCounter + kernelModuleLoads observability.Counter + firewallRulesApplied observability.Counter + activeSessions observability.UpDownCounter + activeProxyConnections observability.UpDownCounter + proxyRouteLookups observability.Counter + proxyTLSHandshake observability.Histogram + proxyBytesTransmitted observability.Counter + + // UDP Relay / Proxy Metrics + udpPacketsTotal observability.Counter + udpPacketSizeBytes observability.Histogram + holePunchEventsTotal observability.Counter + proxyMappingActive observability.UpDownCounter + sessionActive observability.UpDownCounter + sessionRebuiltTotal observability.Counter + commPatternActive observability.UpDownCounter + proxyCleanupRemovedTotal observability.Counter + proxyConnectionErrorsTotal observability.Counter + proxyInitialMappingsTotal observability.Int64Gauge + proxyMappingUpdatesTotal observability.Counter + proxyIdleCleanupDuration observability.Histogram + + // SNI Proxy Metrics + sniConnectionsTotal observability.Counter + sniConnectionDuration observability.Histogram + sniActiveConnections observability.UpDownCounter + sniRouteCacheHitsTotal observability.Counter + sniRouteAPIRequestsTotal observability.Counter + sniRouteAPILatency observability.Histogram + sniLocalOverrideTotal 
observability.Counter + sniTrustedProxyEventsTotal observability.Counter + sniProxyProtocolParseErrorsTotal observability.Counter + sniDataBytesTotal observability.Counter + sniTunnelTerminationsTotal observability.Counter + + // HTTP API & Peer Management Metrics + httpRequestsTotal observability.Counter + httpRequestDuration observability.Histogram + peerOperationsTotal observability.Counter + proxyMappingUpdateRequestsTotal observability.Counter + destinationsUpdateRequestsTotal observability.Counter + + // Remote Configuration, Reporting & Housekeeping + remoteConfigFetchesTotal observability.Counter + bandwidthReportsTotal observability.Counter + peerBandwidthBytesTotal observability.Counter + memorySpikeTotal observability.Counter + heapProfilesWrittenTotal observability.Counter + + // Operational metrics + configReloadsTotal observability.Counter + restartTotal observability.Counter + authFailuresTotal observability.Counter + aclDeniedTotal observability.Counter + certificateExpiryDays observability.Float64Gauge +) + +// DefaultConfig returns a default metrics configuration. +func DefaultConfig() Config { + return observability.DefaultMetricsConfig() +} + +// Initialize sets up the metrics system using the selected backend. +// It returns the /metrics HTTP handler (non-nil only for Prometheus backend). +func Initialize(cfg Config) (http.Handler, error) { + b, err := observability.New(cfg) + if err != nil { + return nil, err + } + backend = b + + if err := createInstruments(); err != nil { + return nil, err + } + + return backend.HTTPHandler(), nil +} + +// Shutdown gracefully shuts down the metrics backend. 
+func Shutdown(ctx context.Context) error { + if backend != nil { + return backend.Shutdown(ctx) + } + return nil +} + +func createInstruments() error { + durationBuckets := []float64{0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1, 2.5, 5, 10, 30} + sizeBuckets := []float64{512, 1024, 4096, 16384, 65536, 262144, 1048576} + sniDurationBuckets := []float64{0.1, 0.5, 1, 2.5, 5, 10, 30, 60, 120} + + b := backend + + wgInterfaceUp = b.NewInt64Gauge("gerbil_wg_interface_up", + "Operational state of a WireGuard interface (1=up, 0=down)", "ifname", "instance") + wgPeersTotal = b.NewUpDownCounter("gerbil_wg_peers_total", + "Total number of configured peers per interface", "ifname") + wgPeerConnected = b.NewInt64Gauge("gerbil_wg_peer_connected", + "Whether a specific peer is connected (1=connected, 0=disconnected)", "ifname", "peer") + allowedIPsCount = b.NewUpDownCounter("gerbil_allowed_ips_count", + "Number of allowed IPs configured per peer", "ifname", "peer") + keyRotationTotal = b.NewCounter("gerbil_key_rotation_total", + "Key rotation events", "ifname", "reason") + wgHandshakesTotal = b.NewCounter("gerbil_wg_handshakes_total", + "Count of handshake attempts with their result status", "ifname", "peer", "result") + wgHandshakeLatency = b.NewHistogram("gerbil_wg_handshake_latency_seconds", + "Distribution of handshake latencies in seconds", durationBuckets, "ifname", "peer") + wgPeerRTT = b.NewHistogram("gerbil_wg_peer_rtt_seconds", + "Observed round-trip time to a peer in seconds", durationBuckets, "ifname", "peer") + wgBytesReceived = b.NewCounter("gerbil_wg_bytes_received_total", + "Number of bytes received from a peer", "ifname", "peer") + wgBytesTransmitted = b.NewCounter("gerbil_wg_bytes_transmitted_total", + "Number of bytes transmitted to a peer", "ifname", "peer") + netlinkEventsTotal = b.NewCounter("gerbil_netlink_events_total", + "Number of netlink events processed", "event_type") + netlinkErrorsTotal = b.NewCounter("gerbil_netlink_errors_total", + "Count of 
netlink or kernel errors", "component", "error_type") + syncDuration = b.NewHistogram("gerbil_sync_duration_seconds", + "Duration of reconciliation/sync loops in seconds", durationBuckets, "component") + workqueueDepth = b.NewUpDownCounter("gerbil_workqueue_depth", + "Current length of internal work queues", "queue") + kernelModuleLoads = b.NewCounter("gerbil_kernel_module_loads_total", + "Count of kernel module load attempts", "result") + firewallRulesApplied = b.NewCounter("gerbil_firewall_rules_applied_total", + "IPTables/NFT rules applied", "result", "chain") + activeSessions = b.NewUpDownCounter("gerbil_active_sessions", + "Number of active UDP relay sessions", "ifname") + activeProxyConnections = b.NewUpDownCounter("gerbil_active_proxy_connections", + "Active SNI proxy connections") + proxyRouteLookups = b.NewCounter("gerbil_proxy_route_lookups_total", + "Number of route lookups", "result") + proxyTLSHandshake = b.NewHistogram("gerbil_proxy_tls_handshake_seconds", + "TLS handshake duration for SNI proxy in seconds", durationBuckets) + proxyBytesTransmitted = b.NewCounter("gerbil_proxy_bytes_transmitted_total", + "Bytes sent/received by the SNI proxy", "direction") + configReloadsTotal = b.NewCounter("gerbil_config_reloads_total", + "Number of configuration reloads", "result") + restartTotal = b.NewCounter("gerbil_restart_total", + "Process restart count") + authFailuresTotal = b.NewCounter("gerbil_auth_failures_total", + "Count of authentication or peer validation failures", "peer", "reason") + aclDeniedTotal = b.NewCounter("gerbil_acl_denied_total", + "Access control denied events", "ifname", "peer", "policy") + certificateExpiryDays = b.NewFloat64Gauge("gerbil_certificate_expiry_days", + "Days until certificate expiry", "cert_name", "ifname") + udpPacketsTotal = b.NewCounter("gerbil_udp_packets_total", + "Count of UDP packets processed by relay workers", "ifname", "type", "direction") + udpPacketSizeBytes = b.NewHistogram("gerbil_udp_packet_size_bytes", + 
"Size distribution of packets forwarded through relay", sizeBuckets, "ifname", "type") + holePunchEventsTotal = b.NewCounter("gerbil_hole_punch_events_total", + "Count of hole punch messages processed", "ifname", "result") + proxyMappingActive = b.NewUpDownCounter("gerbil_proxy_mapping_active", + "Number of active proxy mappings", "ifname") + sessionActive = b.NewUpDownCounter("gerbil_session_active", + "Number of active WireGuard sessions", "ifname") + sessionRebuiltTotal = b.NewCounter("gerbil_session_rebuilt_total", + "Count of sessions rebuilt from communication patterns", "ifname") + commPatternActive = b.NewUpDownCounter("gerbil_comm_pattern_active", + "Number of active communication patterns", "ifname") + proxyCleanupRemovedTotal = b.NewCounter("gerbil_proxy_cleanup_removed_total", + "Count of items removed during cleanup routines", "ifname", "component") + proxyConnectionErrorsTotal = b.NewCounter("gerbil_proxy_connection_errors_total", + "Count of connection errors in proxy operations", "ifname", "error_type") + proxyInitialMappingsTotal = b.NewInt64Gauge("gerbil_proxy_initial_mappings", + "Number of initial proxy mappings loaded", "ifname") + proxyMappingUpdatesTotal = b.NewCounter("gerbil_proxy_mapping_updates_total", + "Count of proxy mapping updates", "ifname") + proxyIdleCleanupDuration = b.NewHistogram("gerbil_proxy_idle_cleanup_duration_seconds", + "Duration of cleanup cycles", durationBuckets, "ifname", "component") + sniConnectionsTotal = b.NewCounter("gerbil_sni_connections_total", + "Count of connections processed by SNI proxy", "result") + sniConnectionDuration = b.NewHistogram("gerbil_sni_connection_duration_seconds", + "Lifetime distribution of proxied TLS connections", sniDurationBuckets) + sniActiveConnections = b.NewUpDownCounter("gerbil_sni_active_connections", + "Number of active SNI tunnels") + sniRouteCacheHitsTotal = b.NewCounter("gerbil_sni_route_cache_hits_total", + "Count of route cache hits and misses", "result") + 
sniRouteAPIRequestsTotal = b.NewCounter("gerbil_sni_route_api_requests_total", + "Count of route API requests", "result") + sniRouteAPILatency = b.NewHistogram("gerbil_sni_route_api_latency_seconds", + "Distribution of route API call latencies", durationBuckets) + sniLocalOverrideTotal = b.NewCounter("gerbil_sni_local_override_total", + "Count of routes using local overrides", "hit") + sniTrustedProxyEventsTotal = b.NewCounter("gerbil_sni_trusted_proxy_events_total", + "Count of PROXY protocol events", "event") + sniProxyProtocolParseErrorsTotal = b.NewCounter("gerbil_sni_proxy_protocol_parse_errors_total", + "Count of PROXY protocol parse failures") + sniDataBytesTotal = b.NewCounter("gerbil_sni_data_bytes_total", + "Count of bytes proxied through SNI tunnels", "direction") + sniTunnelTerminationsTotal = b.NewCounter("gerbil_sni_tunnel_terminations_total", + "Count of tunnel terminations by reason", "reason") + httpRequestsTotal = b.NewCounter("gerbil_http_requests_total", + "Count of HTTP requests to management API", "endpoint", "method", "status_code") + httpRequestDuration = b.NewHistogram("gerbil_http_request_duration_seconds", + "Distribution of HTTP request handling time", durationBuckets, "endpoint", "method") + peerOperationsTotal = b.NewCounter("gerbil_peer_operations_total", + "Count of peer lifecycle operations", "operation", "result") + proxyMappingUpdateRequestsTotal = b.NewCounter("gerbil_proxy_mapping_update_requests_total", + "Count of proxy mapping update API calls", "result") + destinationsUpdateRequestsTotal = b.NewCounter("gerbil_destinations_update_requests_total", + "Count of destinations update API calls", "result") + remoteConfigFetchesTotal = b.NewCounter("gerbil_remote_config_fetches_total", + "Count of remote configuration fetch attempts", "result") + bandwidthReportsTotal = b.NewCounter("gerbil_bandwidth_reports_total", + "Count of bandwidth report transmissions", "result") + peerBandwidthBytesTotal = 
b.NewCounter("gerbil_peer_bandwidth_bytes_total", + "Bytes per peer tracked by bandwidth calculation", "peer", "direction") + memorySpikeTotal = b.NewCounter("gerbil_memory_spike_total", + "Count of memory spikes detected", "severity") + heapProfilesWrittenTotal = b.NewCounter("gerbil_heap_profiles_written_total", + "Count of heap profile files generated") + + return nil +} + +func RecordInterfaceUp(ifname, instance string, up bool) { + value := int64(0) + if up { + value = 1 + } + wgInterfaceUp.Record(context.Background(), value, observability.Labels{"ifname": ifname, "instance": instance}) +} + +func RecordPeersTotal(ifname string, delta int64) { + wgPeersTotal.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) +} + +func RecordPeerConnected(ifname, peer string, connected bool) { + value := int64(0) + if connected { + value = 1 + } + wgPeerConnected.Record(context.Background(), value, observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordHandshake(ifname, peer, result string) { + wgHandshakesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "result": result}) +} + +func RecordHandshakeLatency(ifname, peer string, seconds float64) { + wgHandshakeLatency.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordPeerRTT(ifname, peer string, seconds float64) { + wgPeerRTT.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordBytesReceived(ifname, peer string, bytes int64) { + wgBytesReceived.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordBytesTransmitted(ifname, peer string, bytes int64) { + wgBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordAllowedIPsCount(ifname, peer string, delta int64) { + allowedIPsCount.Add(context.Background(), delta, 
observability.Labels{"ifname": ifname, "peer": peer}) +} + +func RecordKeyRotation(ifname, reason string) { + keyRotationTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "reason": reason}) +} + +func RecordNetlinkEvent(eventType string) { + netlinkEventsTotal.Add(context.Background(), 1, observability.Labels{"event_type": eventType}) +} + +func RecordNetlinkError(component, errorType string) { + netlinkErrorsTotal.Add(context.Background(), 1, observability.Labels{"component": component, "error_type": errorType}) +} + +func RecordSyncDuration(component string, seconds float64) { + syncDuration.Record(context.Background(), seconds, observability.Labels{"component": component}) +} + +func RecordWorkqueueDepth(queue string, delta int64) { + workqueueDepth.Add(context.Background(), delta, observability.Labels{"queue": queue}) +} + +func RecordKernelModuleLoad(result string) { + kernelModuleLoads.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordFirewallRuleApplied(result, chain string) { + firewallRulesApplied.Add(context.Background(), 1, observability.Labels{"result": result, "chain": chain}) +} + +func RecordActiveSession(ifname string, delta int64) { + activeSessions.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) +} + +func RecordActiveProxyConnection(hostname string, delta int64) { + _ = hostname + activeProxyConnections.Add(context.Background(), delta, nil) +} + +func RecordProxyRouteLookup(result, hostname string) { + _ = hostname + proxyRouteLookups.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordProxyTLSHandshake(hostname string, seconds float64) { + _ = hostname + proxyTLSHandshake.Record(context.Background(), seconds, nil) +} + +func RecordProxyBytesTransmitted(hostname, direction string, bytes int64) { + _ = hostname + proxyBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"direction": direction}) +} + +func 
RecordConfigReload(result string) { + configReloadsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordRestart() { + restartTotal.Add(context.Background(), 1, nil) +} + +func RecordAuthFailure(peer, reason string) { + authFailuresTotal.Add(context.Background(), 1, observability.Labels{"peer": peer, "reason": reason}) +} + +func RecordACLDenied(ifname, peer, policy string) { + aclDeniedTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "policy": policy}) +} + +func RecordCertificateExpiry(certName, ifname string, days float64) { + certificateExpiryDays.Record(context.Background(), days, observability.Labels{"cert_name": certName, "ifname": ifname}) +} + +func RecordUDPPacket(ifname, packetType, direction string) { + udpPacketsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "type": packetType, "direction": direction}) +} + +func RecordUDPPacketSize(ifname, packetType string, bytes float64) { + udpPacketSizeBytes.Record(context.Background(), bytes, observability.Labels{"ifname": ifname, "type": packetType}) +} + +func RecordHolePunchEvent(ifname, result string) { + holePunchEventsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "result": result}) +} + +func RecordProxyMapping(ifname string, delta int64) { + proxyMappingActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) +} + +func RecordSession(ifname string, delta int64) { + sessionActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) +} + +func RecordSessionRebuilt(ifname string) { + sessionRebuiltTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) +} + +func RecordCommPattern(ifname string, delta int64) { + commPatternActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) +} + +func RecordProxyCleanupRemoved(ifname, component string, count int64) { + 
proxyCleanupRemovedTotal.Add(context.Background(), count, observability.Labels{"ifname": ifname, "component": component}) +} + +func RecordProxyConnectionError(ifname, errorType string) { + proxyConnectionErrorsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "error_type": errorType}) +} + +func RecordProxyInitialMappings(ifname string, count int64) { + proxyInitialMappingsTotal.Record(context.Background(), count, observability.Labels{"ifname": ifname}) +} + +func RecordProxyMappingUpdate(ifname string) { + proxyMappingUpdatesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) +} + +func RecordProxyIdleCleanupDuration(ifname, component string, seconds float64) { + proxyIdleCleanupDuration.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "component": component}) +} + +func RecordSNIConnection(result string) { + sniConnectionsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordSNIConnectionDuration(seconds float64) { + sniConnectionDuration.Record(context.Background(), seconds, nil) +} + +func RecordSNIActiveConnection(delta int64) { + sniActiveConnections.Add(context.Background(), delta, nil) +} + +func RecordSNIRouteCacheHit(result string) { + sniRouteCacheHitsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordSNIRouteAPIRequest(result string) { + sniRouteAPIRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordSNIRouteAPILatency(seconds float64) { + sniRouteAPILatency.Record(context.Background(), seconds, nil) +} + +func RecordSNILocalOverride(hit string) { + sniLocalOverrideTotal.Add(context.Background(), 1, observability.Labels{"hit": hit}) +} + +func RecordSNITrustedProxyEvent(event string) { + sniTrustedProxyEventsTotal.Add(context.Background(), 1, observability.Labels{"event": event}) +} + +func RecordSNIProxyProtocolParseError() { + 
sniProxyProtocolParseErrorsTotal.Add(context.Background(), 1, nil) +} + +func RecordSNIDataBytes(direction string, bytes int64) { + sniDataBytesTotal.Add(context.Background(), bytes, observability.Labels{"direction": direction}) +} + +func RecordSNITunnelTermination(reason string) { + sniTunnelTerminationsTotal.Add(context.Background(), 1, observability.Labels{"reason": reason}) +} + +func RecordHTTPRequest(endpoint, method, statusCode string) { + httpRequestsTotal.Add(context.Background(), 1, observability.Labels{"endpoint": endpoint, "method": method, "status_code": statusCode}) +} + +func RecordHTTPRequestDuration(endpoint, method string, seconds float64) { + httpRequestDuration.Record(context.Background(), seconds, observability.Labels{"endpoint": endpoint, "method": method}) +} + +func RecordPeerOperation(operation, result string) { + peerOperationsTotal.Add(context.Background(), 1, observability.Labels{"operation": operation, "result": result}) +} + +func RecordProxyMappingUpdateRequest(result string) { + proxyMappingUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordDestinationsUpdateRequest(result string) { + destinationsUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordRemoteConfigFetch(result string) { + remoteConfigFetchesTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordBandwidthReport(result string) { + bandwidthReportsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) +} + +func RecordPeerBandwidthBytes(peer, direction string, bytes int64) { + peerBandwidthBytesTotal.Add(context.Background(), bytes, observability.Labels{"peer": peer, "direction": direction}) +} + +func RecordMemorySpike(severity string) { + memorySpikeTotal.Add(context.Background(), 1, observability.Labels{"severity": severity}) +} + +func RecordHeapProfileWritten() { + 
heapProfilesWrittenTotal.Add(context.Background(), 1, nil) +} diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go new file mode 100644 index 0000000..132c3fe --- /dev/null +++ b/internal/metrics/metrics_test.go @@ -0,0 +1,258 @@ +package metrics_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/fosrl/gerbil/internal/metrics" + "github.com/fosrl/gerbil/internal/observability" +) + +const exampleHostname = "example.com" + +func initPrometheus(t *testing.T) http.Handler { + t.Helper() + cfg := metrics.DefaultConfig() + cfg.Enabled = true + cfg.Backend = "prometheus" + cfg.Prometheus.Path = "/metrics" + + h, err := metrics.Initialize(cfg) + if err != nil { + t.Fatalf("Initialize failed: %v", err) + } + t.Cleanup(func() { + metrics.Shutdown(context.Background()) //nolint:errcheck + }) + return h +} + +func initNoop(t *testing.T) { + t.Helper() + cfg := metrics.DefaultConfig() + cfg.Enabled = false + _, err := metrics.Initialize(cfg) + if err != nil { + t.Fatalf("Initialize noop failed: %v", err) + } + t.Cleanup(func() { + metrics.Shutdown(context.Background()) //nolint:errcheck + }) +} + +func scrape(t *testing.T, h http.Handler) string { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/metrics", http.NoBody) + rr := httptest.NewRecorder() + h.ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("scrape returned %d", rr.Code) + } + b, _ := io.ReadAll(rr.Body) + return string(b) +} + +func assertContains(t *testing.T, body, substr string) { + t.Helper() + if !strings.Contains(body, substr) { + t.Errorf("expected %q in output\nbody:\n%s", substr, body) + } +} + +// --- Tests --- + +func TestInitializePrometheus(t *testing.T) { + h := initPrometheus(t) + if h == nil { + t.Error("expected non-nil HTTP handler for prometheus backend") + } +} + +func TestInitializeNoop(t *testing.T) { + initNoop(t) + // All Record* functions must not panic when noop backend is active. 
+ metrics.RecordRestart() + metrics.RecordHTTPRequest("/test", "GET", "200") + metrics.RecordSNIConnection("accepted") + metrics.RecordPeersTotal("wg0", 1) +} + +func TestDefaultConfig(t *testing.T) { + cfg := metrics.DefaultConfig() + if cfg.Backend != "prometheus" { + t.Errorf("expected prometheus default backend, got %q", cfg.Backend) + } +} + +func TestShutdownNoInit(t *testing.T) { + // Shutdown without Initialize should not panic or error. + if err := metrics.Shutdown(context.Background()); err != nil { + t.Errorf("unexpected error: %v", err) + } +} + +func TestRecordHTTPRequest(t *testing.T) { + h := initPrometheus(t) + metrics.RecordHTTPRequest("/peers", "POST", "201") + body := scrape(t, h) + assertContains(t, body, "gerbil_http_requests_total") +} + +func TestRecordHTTPRequestDuration(t *testing.T) { + h := initPrometheus(t) + metrics.RecordHTTPRequestDuration("/peers", "POST", 0.05) + body := scrape(t, h) + assertContains(t, body, "gerbil_http_request_duration_seconds") +} + +func TestRecordInterfaceUp(t *testing.T) { + h := initPrometheus(t) + metrics.RecordInterfaceUp("wg0", "host1", true) + metrics.RecordInterfaceUp("wg0", "host1", false) + body := scrape(t, h) + assertContains(t, body, "gerbil_wg_interface_up") +} + +func TestRecordPeersTotal(t *testing.T) { + h := initPrometheus(t) + metrics.RecordPeersTotal("wg0", 3) + body := scrape(t, h) + assertContains(t, body, "gerbil_wg_peers_total") +} + +func TestRecordBytesReceivedTransmitted(t *testing.T) { + h := initPrometheus(t) + metrics.RecordBytesReceived("wg0", "peer1", 1024) + metrics.RecordBytesTransmitted("wg0", "peer1", 512) + body := scrape(t, h) + assertContains(t, body, "gerbil_wg_bytes_received_total") + assertContains(t, body, "gerbil_wg_bytes_transmitted_total") +} + +func TestRecordSNI(t *testing.T) { + h := initPrometheus(t) + metrics.RecordSNIConnection("accepted") + metrics.RecordSNIActiveConnection(1) + metrics.RecordSNIConnectionDuration(1.5) + metrics.RecordSNIRouteCacheHit("hit") 
+ metrics.RecordSNIRouteAPIRequest("success") + metrics.RecordSNIRouteAPILatency(0.01) + metrics.RecordSNILocalOverride("yes") + metrics.RecordSNITrustedProxyEvent("proxy_protocol_parsed") + metrics.RecordSNIProxyProtocolParseError() + metrics.RecordSNIDataBytes("client_to_target", 2048) + metrics.RecordSNITunnelTermination("eof") + body := scrape(t, h) + assertContains(t, body, "gerbil_sni_connections_total") + assertContains(t, body, "gerbil_sni_active_connections") +} + +func TestRecordRelay(t *testing.T) { + h := initPrometheus(t) + metrics.RecordUDPPacket("relay", "data", "in") + metrics.RecordUDPPacketSize("relay", "data", 256) + metrics.RecordHolePunchEvent("relay", "success") + metrics.RecordProxyMapping("relay", 1) + metrics.RecordSession("relay", 1) + metrics.RecordSessionRebuilt("relay") + metrics.RecordCommPattern("relay", 1) + metrics.RecordProxyCleanupRemoved("relay", "session", 2) + metrics.RecordProxyConnectionError("relay", "dial_udp") + metrics.RecordProxyInitialMappings("relay", 5) + metrics.RecordProxyMappingUpdate("relay") + metrics.RecordProxyIdleCleanupDuration("relay", "conn", 0.1) + body := scrape(t, h) + assertContains(t, body, "gerbil_udp_packets_total") + assertContains(t, body, "gerbil_proxy_mapping_active") +} + +func TestRecordWireGuard(t *testing.T) { + h := initPrometheus(t) + metrics.RecordHandshake("wg0", "peer1", "success") + metrics.RecordHandshakeLatency("wg0", "peer1", 0.02) + metrics.RecordPeerRTT("wg0", "peer1", 0.005) + metrics.RecordPeerConnected("wg0", "peer1", true) + metrics.RecordAllowedIPsCount("wg0", "peer1", 2) + metrics.RecordKeyRotation("wg0", "scheduled") + body := scrape(t, h) + assertContains(t, body, "gerbil_wg_handshakes_total") + assertContains(t, body, "gerbil_wg_peer_connected") +} + +func TestRecordHousekeeping(t *testing.T) { + h := initPrometheus(t) + metrics.RecordRemoteConfigFetch("success") + metrics.RecordBandwidthReport("success") + metrics.RecordPeerBandwidthBytes("peer1", "rx", 512) + 
metrics.RecordMemorySpike("warning") + metrics.RecordHeapProfileWritten() + body := scrape(t, h) + assertContains(t, body, "gerbil_remote_config_fetches_total") + assertContains(t, body, "gerbil_memory_spike_total") +} + +func TestRecordOperational(t *testing.T) { + h := initPrometheus(t) + metrics.RecordConfigReload("success") + metrics.RecordRestart() + metrics.RecordAuthFailure("peer1", "bad_key") + metrics.RecordACLDenied("wg0", "peer1", "default-deny") + metrics.RecordCertificateExpiry(exampleHostname, "wg0", 90.0) + body := scrape(t, h) + assertContains(t, body, "gerbil_config_reloads_total") + assertContains(t, body, "gerbil_restart_total") +} + +func TestRecordNetlink(t *testing.T) { + h := initPrometheus(t) + metrics.RecordNetlinkEvent("link_up") + metrics.RecordNetlinkError("wg", "timeout") + metrics.RecordSyncDuration("config", 0.1) + metrics.RecordWorkqueueDepth("main", 3) + metrics.RecordKernelModuleLoad("success") + metrics.RecordFirewallRuleApplied("success", "INPUT") + metrics.RecordActiveSession("wg0", 1) + metrics.RecordActiveProxyConnection(exampleHostname, 1) + metrics.RecordProxyRouteLookup("hit", exampleHostname) + metrics.RecordProxyTLSHandshake(exampleHostname, 0.05) + metrics.RecordProxyBytesTransmitted(exampleHostname, "tx", 1024) + body := scrape(t, h) + assertContains(t, body, "gerbil_netlink_events_total") + assertContains(t, body, "gerbil_active_sessions") +} + +func TestRecordPeerOperation(t *testing.T) { + h := initPrometheus(t) + metrics.RecordPeerOperation("add", "success") + metrics.RecordProxyMappingUpdateRequest("success") + metrics.RecordDestinationsUpdateRequest("success") + body := scrape(t, h) + assertContains(t, body, "gerbil_peer_operations_total") +} + +func TestInitializeInvalidBackend(t *testing.T) { + cfg := observability.MetricsConfig{Enabled: true, Backend: "invalid"} + _, err := metrics.Initialize(cfg) + if err == nil { + t.Error("expected error for invalid backend") + } +} + +func TestInitializeBackendNone(t 
*testing.T) { + cfg := metrics.DefaultConfig() + cfg.Backend = "none" + h, err := metrics.Initialize(cfg) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h != nil { + t.Error("none backend should return nil handler") + } + // All Record* calls should be noop + metrics.RecordRestart() + metrics.Shutdown(context.Background()) //nolint:errcheck +} diff --git a/internal/observability/config.go b/internal/observability/config.go new file mode 100644 index 0000000..9643727 --- /dev/null +++ b/internal/observability/config.go @@ -0,0 +1,119 @@ +// Package observability provides a backend-neutral metrics abstraction for Gerbil. +// +// Exactly one metrics backend may be enabled at runtime: +// - "prometheus" – native Prometheus client; exposes /metrics (no OTel SDK required) +// - "otel" – OpenTelemetry metrics pushed via OTLP (gRPC or HTTP) +// - "none" – metrics disabled; a safe noop implementation is used +// +// Future OTel tracing and logging can be added to this package alongside the +// existing otel sub-package without touching the Prometheus-native path. +package observability + +import ( + "fmt" + "time" +) + +// MetricsConfig is the top-level metrics configuration. +type MetricsConfig struct { + // Enabled controls whether any metrics backend is started. + // When false the noop backend is used regardless of Backend. + Enabled bool + + // Backend selects the active backend: "prometheus", "otel", or "none". + Backend string + + // Prometheus holds settings used only by the Prometheus-native backend. + Prometheus PrometheusConfig + + // OTel holds settings used only by the OTel backend. + OTel OTelConfig + + // ServiceName is propagated to OTel resource attributes. + ServiceName string + + // ServiceVersion is propagated to OTel resource attributes. + ServiceVersion string + + // DeploymentEnvironment is an optional OTel resource attribute. + DeploymentEnvironment string +} + +// PrometheusConfig holds Prometheus-native backend settings. 
+type PrometheusConfig struct { + // Path is the HTTP path to expose the /metrics endpoint. + // Defaults to "/metrics". + Path string +} + +// OTelConfig holds OpenTelemetry backend settings. +type OTelConfig struct { + // Protocol is the OTLP transport: "grpc" (default) or "http". + Protocol string + + // Endpoint is the OTLP collector address (e.g. "localhost:4317"). + Endpoint string + + // Insecure disables TLS for the OTLP connection. + Insecure bool + + // ExportInterval is how often metrics are pushed to the collector. + // Defaults to 60 s. + ExportInterval time.Duration +} + +// DefaultMetricsConfig returns a MetricsConfig with sensible defaults. +func DefaultMetricsConfig() MetricsConfig { + return MetricsConfig{ + Enabled: true, + Backend: "prometheus", + Prometheus: PrometheusConfig{ + Path: "/metrics", + }, + OTel: OTelConfig{ + Protocol: "grpc", + Endpoint: "localhost:4317", + Insecure: true, + ExportInterval: 60 * time.Second, + }, + ServiceName: "gerbil", + ServiceVersion: "1.0.0", + } +} + +// Validate checks the configuration for logical errors. +func (c *MetricsConfig) Validate() error { + if !c.Enabled { + return nil + } + + switch c.Backend { + case "prometheus", "none", "": + // valid + case "otel": + if c.OTel.Endpoint == "" { + return fmt.Errorf("metrics: backend=otel requires a non-empty OTel endpoint") + } + if c.OTel.Protocol != "grpc" && c.OTel.Protocol != "http" { + return fmt.Errorf("metrics: otel protocol must be \"grpc\" or \"http\", got %q", c.OTel.Protocol) + } + if c.OTel.ExportInterval <= 0 { + return fmt.Errorf("metrics: otel export interval must be positive") + } + default: + return fmt.Errorf("metrics: unknown backend %q (must be \"prometheus\", \"otel\", or \"none\")", c.Backend) + } + + return nil +} + +// effectiveBackend resolves the backend string, treating "" and "none" as noop. 
+func (c *MetricsConfig) effectiveBackend() string { + if !c.Enabled { + return "none" + } + if c.Backend == "" { + return "none" + } + return c.Backend +} diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go new file mode 100644 index 0000000..ff2e2ae --- /dev/null +++ b/internal/observability/metrics.go @@ -0,0 +1,152 @@ +package observability + +import ( + "context" + "fmt" + "net/http" + + obsotel "github.com/fosrl/gerbil/internal/observability/otel" + obsprom "github.com/fosrl/gerbil/internal/observability/prometheus" +) + +// Labels is a set of key-value pairs attached to a metric observation. +// Use only stable, bounded-cardinality label values. +type Labels = map[string]string + +// Counter is a monotonically increasing instrument. +type Counter interface { + Add(ctx context.Context, value int64, labels Labels) +} + +// UpDownCounter is a bidirectional integer instrument (can go up or down). +type UpDownCounter interface { + Add(ctx context.Context, value int64, labels Labels) +} + +// Int64Gauge records a snapshot integer value. +type Int64Gauge interface { + Record(ctx context.Context, value int64, labels Labels) +} + +// Float64Gauge records a snapshot float value. +type Float64Gauge interface { + Record(ctx context.Context, value float64, labels Labels) +} + +// Histogram records a distribution of values. +type Histogram interface { + Record(ctx context.Context, value float64, labels Labels) +} + +// Backend is the single interface that each metrics implementation must satisfy. +// Application code must not import backend-specific packages (prometheus, otel). +type Backend interface { + // NewCounter creates a counter metric. + // labelNames declares the set of label keys that will be passed at observation time. + NewCounter(name, desc string, labelNames ...string) Counter + + // NewUpDownCounter creates an up-down counter metric. 
+ NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter + + // NewInt64Gauge creates an integer gauge metric. + NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge + + // NewFloat64Gauge creates a float gauge metric. + NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge + + // NewHistogram creates a histogram metric. + // buckets are the explicit upper-bound bucket boundaries. + NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram + + // HTTPHandler returns the /metrics HTTP handler. + // Implementations that do not expose an HTTP endpoint return nil. + HTTPHandler() http.Handler + + // Shutdown performs a graceful flush / shutdown of the backend. + Shutdown(ctx context.Context) error +} + +// New creates the backend selected by cfg and returns it. +// Exactly one backend is created; the selection is mutually exclusive. +func New(cfg MetricsConfig) (Backend, error) { + if err := cfg.Validate(); err != nil { + return nil, err + } + + switch cfg.effectiveBackend() { + case "prometheus": + b, err := obsprom.New(obsprom.Config{ + Path: cfg.Prometheus.Path, + }) + if err != nil { + return nil, err + } + return &promAdapter{b: b}, nil + case "otel": + b, err := obsotel.New(obsotel.Config{ + Protocol: cfg.OTel.Protocol, + Endpoint: cfg.OTel.Endpoint, + Insecure: cfg.OTel.Insecure, + ExportInterval: cfg.OTel.ExportInterval, + ServiceName: cfg.ServiceName, + ServiceVersion: cfg.ServiceVersion, + DeploymentEnvironment: cfg.DeploymentEnvironment, + }) + if err != nil { + return nil, err + } + return &otelAdapter{b: b}, nil + case "none": + return &NoopBackend{}, nil + default: + return nil, fmt.Errorf("observability: unknown backend %q", cfg.effectiveBackend()) + } +} + +// promAdapter wraps obsprom.Backend to implement the observability.Backend interface. 
+// The concrete instrument types from the prometheus sub-package satisfy the instrument +// interfaces via Go's structural (duck) typing without importing this package. +type promAdapter struct { + b *obsprom.Backend +} + +func (a *promAdapter) NewCounter(name, desc string, labelNames ...string) Counter { + return a.b.NewCounter(name, desc, labelNames...) +} +func (a *promAdapter) NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter { + return a.b.NewUpDownCounter(name, desc, labelNames...) +} +func (a *promAdapter) NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge { + return a.b.NewInt64Gauge(name, desc, labelNames...) +} +func (a *promAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge { + return a.b.NewFloat64Gauge(name, desc, labelNames...) +} +func (a *promAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram { + return a.b.NewHistogram(name, desc, buckets, labelNames...) +} +func (a *promAdapter) HTTPHandler() http.Handler { return a.b.HTTPHandler() } +func (a *promAdapter) Shutdown(ctx context.Context) error { return a.b.Shutdown(ctx) } + +// otelAdapter wraps obsotel.Backend to implement the observability.Backend interface. +type otelAdapter struct { + b *obsotel.Backend +} + +func (a *otelAdapter) NewCounter(name, desc string, labelNames ...string) Counter { + return a.b.NewCounter(name, desc, labelNames...) +} +func (a *otelAdapter) NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter { + return a.b.NewUpDownCounter(name, desc, labelNames...) +} +func (a *otelAdapter) NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge { + return a.b.NewInt64Gauge(name, desc, labelNames...) +} +func (a *otelAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge { + return a.b.NewFloat64Gauge(name, desc, labelNames...) 
+} +func (a *otelAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram { + return a.b.NewHistogram(name, desc, buckets, labelNames...) +} +func (a *otelAdapter) HTTPHandler() http.Handler { return a.b.HTTPHandler() } +func (a *otelAdapter) Shutdown(ctx context.Context) error { return a.b.Shutdown(ctx) } diff --git a/internal/observability/metrics_test.go b/internal/observability/metrics_test.go new file mode 100644 index 0000000..91a048c --- /dev/null +++ b/internal/observability/metrics_test.go @@ -0,0 +1,198 @@ +package observability_test + +import ( + "context" + "testing" + "time" + + "github.com/fosrl/gerbil/internal/observability" +) + +const ( + defaultMetricsPath = "/metrics" + otelGRPCEndpoint = "localhost:4317" + errUnexpectedFmt = "unexpected error: %v" +) + +func TestDefaultMetricsConfig(t *testing.T) { + cfg := observability.DefaultMetricsConfig() + if !cfg.Enabled { + t.Error("default config should have Enabled=true") + } + if cfg.Backend != "prometheus" { + t.Errorf("default backend should be prometheus, got %q", cfg.Backend) + } + if cfg.Prometheus.Path != defaultMetricsPath { + t.Errorf("default prometheus path should be %s, got %q", defaultMetricsPath, cfg.Prometheus.Path) + } + if cfg.OTel.Protocol != "grpc" { + t.Errorf("default otel protocol should be grpc, got %q", cfg.OTel.Protocol) + } + if cfg.OTel.ExportInterval != 60*time.Second { + t.Errorf("default otel export interval should be 60s, got %v", cfg.OTel.ExportInterval) + } +} +func TestValidateValidConfigs(t *testing.T) { + tests := []struct { + name string + cfg observability.MetricsConfig + }{ + {name: "disabled", cfg: observability.MetricsConfig{Enabled: false}}, + {name: "backend none", cfg: observability.MetricsConfig{Enabled: true, Backend: "none"}}, + {name: "backend empty", cfg: observability.MetricsConfig{Enabled: true, Backend: ""}}, + {name: "prometheus", cfg: observability.MetricsConfig{Enabled: true, Backend: "prometheus"}}, + { + name: 
"otel grpc", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second}, + }, + }, + { + name: "otel http", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "http", Endpoint: "localhost:4318", ExportInterval: 30 * time.Second}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.cfg.Validate(); err != nil { + t.Errorf("unexpected validation error: %v", err) + } + }) + } +} + +func TestValidateInvalidConfigs(t *testing.T) { + tests := []struct { + name string + cfg observability.MetricsConfig + }{ + {name: "unknown backend", cfg: observability.MetricsConfig{Enabled: true, Backend: "datadog"}}, + { + name: "otel missing endpoint", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: "", ExportInterval: 10 * time.Second}, + }, + }, + { + name: "otel invalid protocol", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "tcp", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second}, + }, + }, + { + name: "otel zero interval", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 0}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := tt.cfg.Validate(); err == nil { + t.Error("expected validation error but got nil") + } + }) + } +} + +func TestNewNoopBackend(t *testing.T) { + b, err := observability.New(observability.MetricsConfig{Enabled: false}) + if err != nil { + t.Fatalf(errUnexpectedFmt, err) + } + if b.HTTPHandler() != nil { + t.Error("noop backend HTTPHandler should return nil") + } +} + +func TestNewNoneBackend(t *testing.T) { + b, err := 
observability.New(observability.MetricsConfig{Enabled: true, Backend: "none"}) + if err != nil { + t.Fatalf(errUnexpectedFmt, err) + } + if b.HTTPHandler() != nil { + t.Error("none backend HTTPHandler should return nil") + } +} + +func TestNewPrometheusBackend(t *testing.T) { + cfg := observability.MetricsConfig{ + Enabled: true, Backend: "prometheus", + Prometheus: observability.PrometheusConfig{Path: defaultMetricsPath}, + } + b, err := observability.New(cfg) + if err != nil { + t.Fatalf(errUnexpectedFmt, err) + } + if b.HTTPHandler() == nil { + t.Error("prometheus backend HTTPHandler should not be nil") + } + if err := b.Shutdown(context.Background()); err != nil { + t.Errorf("prometheus shutdown error: %v", err) + } +} + +func TestNewInvalidBackend(t *testing.T) { + _, err := observability.New(observability.MetricsConfig{Enabled: true, Backend: "invalid"}) + if err == nil { + t.Error("expected error for invalid backend") + } +} + +func TestPrometheusAdapterAllInstruments(t *testing.T) { + b, err := observability.New(observability.MetricsConfig{ + Enabled: true, Backend: "prometheus", + Prometheus: observability.PrometheusConfig{Path: defaultMetricsPath}, + }) + if err != nil { + t.Fatalf("failed to create backend: %v", err) + } + ctx := context.Background() + labels := observability.Labels{"k": "v"} + + b.NewCounter("prom_adapter_counter_total", "desc", "k").Add(ctx, 1, labels) + b.NewUpDownCounter("prom_adapter_updown", "desc", "k").Add(ctx, 2, labels) + b.NewInt64Gauge("prom_adapter_int_gauge", "desc", "k").Record(ctx, 99, labels) + b.NewFloat64Gauge("prom_adapter_float_gauge", "desc", "k").Record(ctx, 1.23, labels) + b.NewHistogram("prom_adapter_histogram", "desc", []float64{0.1, 1.0}, "k").Record(ctx, 0.5, labels) + + if b.HTTPHandler() == nil { + t.Error("prometheus adapter HTTPHandler should not be nil") + } + if err := b.Shutdown(ctx); err != nil { + t.Errorf("Shutdown error: %v", err) + } +} + +func TestOtelAdapterAllInstruments(t *testing.T) { + b, err 
:= observability.New(observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, Insecure: true, ExportInterval: 100 * time.Millisecond}, + }) + if err != nil { + t.Fatalf("failed to create otel backend: %v", err) + } + ctx := context.Background() + labels := observability.Labels{"k": "v"} + + b.NewCounter("otel_adapter_counter_total", "desc", "k").Add(ctx, 1, labels) + b.NewUpDownCounter("otel_adapter_updown", "desc", "k").Add(ctx, 2, labels) + b.NewInt64Gauge("otel_adapter_int_gauge", "desc", "k").Record(ctx, 99, labels) + b.NewFloat64Gauge("otel_adapter_float_gauge", "desc", "k").Record(ctx, 1.23, labels) + b.NewHistogram("otel_adapter_histogram", "desc", []float64{0.1, 1.0}, "k").Record(ctx, 0.5, labels) + + if b.HTTPHandler() != nil { + t.Error("OTel adapter HTTPHandler should be nil") + } + + shutdownCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + b.Shutdown(shutdownCtx) //nolint:errcheck +} diff --git a/internal/observability/noop.go b/internal/observability/noop.go new file mode 100644 index 0000000..47acc07 --- /dev/null +++ b/internal/observability/noop.go @@ -0,0 +1,71 @@ +package observability + +import ( + "context" + "net/http" +) + +// NoopBackend is a Backend that discards all observations. +// It is used when metrics are disabled (Enabled=false or Backend="none"). +// All methods are safe to call concurrently. +type NoopBackend struct{} + +// Compile-time interface check. 
+var _ Backend = (*NoopBackend)(nil) + +func (n *NoopBackend) NewCounter(_ string, _ string, _ ...string) Counter { + _ = n + return noopCounter{} +} + +func (n *NoopBackend) NewUpDownCounter(_ string, _ string, _ ...string) UpDownCounter { + _ = n + return noopUpDownCounter{} +} + +func (n *NoopBackend) NewInt64Gauge(_ string, _ string, _ ...string) Int64Gauge { + _ = n + return noopInt64Gauge{} +} + +func (n *NoopBackend) NewFloat64Gauge(_ string, _ string, _ ...string) Float64Gauge { + _ = n + return noopFloat64Gauge{} +} + +func (n *NoopBackend) NewHistogram(_ string, _ string, _ []float64, _ ...string) Histogram { + _ = n + return noopHistogram{} +} + +func (n *NoopBackend) HTTPHandler() http.Handler { + _ = n + return nil +} + +func (n *NoopBackend) Shutdown(_ context.Context) error { + _ = n + return nil +} + +// --- noop instrument types --- + +type noopCounter struct{} + +func (noopCounter) Add(_ context.Context, _ int64, _ Labels) { /* intentionally no-op */ } + +type noopUpDownCounter struct{} + +func (noopUpDownCounter) Add(_ context.Context, _ int64, _ Labels) { /* intentionally no-op */ } + +type noopInt64Gauge struct{} + +func (noopInt64Gauge) Record(_ context.Context, _ int64, _ Labels) { /* intentionally no-op */ } + +type noopFloat64Gauge struct{} + +func (noopFloat64Gauge) Record(_ context.Context, _ float64, _ Labels) { /* intentionally no-op */ } + +type noopHistogram struct{} + +func (noopHistogram) Record(_ context.Context, _ float64, _ Labels) { /* intentionally no-op */ } diff --git a/internal/observability/noop_test.go b/internal/observability/noop_test.go new file mode 100644 index 0000000..9496a0a --- /dev/null +++ b/internal/observability/noop_test.go @@ -0,0 +1,67 @@ +package observability_test + +import ( + "context" + "testing" + + "github.com/fosrl/gerbil/internal/observability" +) + +func TestNoopBackendAllInstruments(t *testing.T) { + n := &observability.NoopBackend{} + + ctx := context.Background() + labels := 
observability.Labels{"k": "v"} + + t.Run("Counter", func(_ *testing.T) { + c := n.NewCounter("test_counter", "desc") + c.Add(ctx, 1, labels) + c.Add(ctx, 0, nil) + }) + + t.Run("UpDownCounter", func(_ *testing.T) { + u := n.NewUpDownCounter("test_updown", "desc") + u.Add(ctx, 1, labels) + u.Add(ctx, -1, nil) + }) + + t.Run("Int64Gauge", func(_ *testing.T) { + g := n.NewInt64Gauge("test_int64gauge", "desc") + g.Record(ctx, 42, labels) + g.Record(ctx, 0, nil) + }) + + t.Run("Float64Gauge", func(_ *testing.T) { + g := n.NewFloat64Gauge("test_float64gauge", "desc") + g.Record(ctx, 3.14, labels) + g.Record(ctx, 0, nil) + }) + + t.Run("Histogram", func(_ *testing.T) { + h := n.NewHistogram("test_histogram", "desc", []float64{1, 5, 10}) + h.Record(ctx, 2.5, labels) + h.Record(ctx, 0, nil) + }) + + t.Run("HTTPHandler", func(t *testing.T) { + if n.HTTPHandler() != nil { + t.Error("noop HTTPHandler should be nil") + } + }) + + t.Run("Shutdown", func(t *testing.T) { + if err := n.Shutdown(ctx); err != nil { + t.Errorf("noop Shutdown should not error: %v", err) + } + }) +} + +func TestNoopBackendLabelNames(_ *testing.T) { + // Verify that label names passed at creation time are accepted without panic. + n := &observability.NoopBackend{} + n.NewCounter("c", "d", "label1", "label2") + n.NewUpDownCounter("u", "d", "l1") + n.NewInt64Gauge("g1", "d", "l1", "l2", "l3") + n.NewFloat64Gauge("g2", "d") + n.NewHistogram("h", "d", []float64{0.1, 1.0}, "l1") +} diff --git a/internal/observability/otel/backend.go b/internal/observability/otel/backend.go new file mode 100644 index 0000000..d3e3a23 --- /dev/null +++ b/internal/observability/otel/backend.go @@ -0,0 +1,210 @@ +// Package otel implements the OpenTelemetry metrics backend for Gerbil. +// +// Metrics are exported via OTLP (gRPC or HTTP) to an external collector. +// No Prometheus /metrics endpoint is exposed in this mode. 
+// Future OTel tracing and logging can be added alongside this package +// without touching the Prometheus-native path. +package otel + +import ( + "context" + "fmt" + "net/http" + "time" + + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" +) + +// Config holds OTel backend configuration. +type Config struct { + // Protocol is "grpc" (default) or "http". + Protocol string + + // Endpoint is the OTLP collector address. + Endpoint string + + // Insecure disables TLS. + Insecure bool + + // ExportInterval is the period between pushes to the collector. + ExportInterval time.Duration + + ServiceName string + ServiceVersion string + DeploymentEnvironment string +} + +// Backend is the OTel metrics backend. +type Backend struct { + cfg Config + provider *sdkmetric.MeterProvider + meter metric.Meter +} + +// New creates and initialises an OTel backend. +// +// cfg.Protocol must be "grpc" (default) or "http". +// cfg.Endpoint is the OTLP collector address (e.g. "localhost:4317"). +// cfg.ExportInterval sets the push period (defaults to 60 s if ≤ 0). +// cfg.Insecure disables TLS on the OTLP connection. +// +// Connection to the collector is established lazily; New only validates cfg +// and creates the SDK components. It returns an error only if the OTel resource +// or exporter cannot be constructed. 
+func New(cfg Config) (*Backend, error) { + if cfg.Protocol == "" { + cfg.Protocol = "grpc" + } + if cfg.ExportInterval <= 0 { + cfg.ExportInterval = 60 * time.Second + } + if cfg.ServiceName == "" { + cfg.ServiceName = "gerbil" + } + + res, err := newResource(cfg.ServiceName, cfg.ServiceVersion, cfg.DeploymentEnvironment) + if err != nil { + return nil, fmt.Errorf("otel backend: build resource: %w", err) + } + + exp, err := newExporter(context.Background(), cfg) + if err != nil { + return nil, fmt.Errorf("otel backend: create exporter: %w", err) + } + + reader := sdkmetric.NewPeriodicReader(exp, + sdkmetric.WithInterval(cfg.ExportInterval), + ) + + provider := sdkmetric.NewMeterProvider( + sdkmetric.WithResource(res), + sdkmetric.WithReader(reader), + ) + + meter := provider.Meter("github.com/fosrl/gerbil") + + return &Backend{cfg: cfg, provider: provider, meter: meter}, nil +} + +// HTTPHandler returns nil – the OTel backend does not expose an HTTP endpoint. +func (b *Backend) HTTPHandler() http.Handler { + _ = b + return nil +} + +// Shutdown flushes pending metrics and shuts down the MeterProvider. +func (b *Backend) Shutdown(ctx context.Context) error { + return b.provider.Shutdown(ctx) +} + +// NewCounter creates an OTel Int64Counter. +func (b *Backend) NewCounter(name, desc string, _ ...string) *Counter { + c, err := b.meter.Int64Counter(name, metric.WithDescription(desc)) + if err != nil { + panic(fmt.Sprintf("otel: create counter %q: %v", name, err)) + } + return &Counter{c: c} +} + +// NewUpDownCounter creates an OTel Int64UpDownCounter. +func (b *Backend) NewUpDownCounter(name, desc string, _ ...string) *UpDownCounter { + c, err := b.meter.Int64UpDownCounter(name, metric.WithDescription(desc)) + if err != nil { + panic(fmt.Sprintf("otel: create up-down counter %q: %v", name, err)) + } + return &UpDownCounter{c: c} +} + +// NewInt64Gauge creates an OTel Int64Gauge. 
+func (b *Backend) NewInt64Gauge(name, desc string, _ ...string) *Int64Gauge { + g, err := b.meter.Int64Gauge(name, metric.WithDescription(desc)) + if err != nil { + panic(fmt.Sprintf("otel: create int64 gauge %q: %v", name, err)) + } + return &Int64Gauge{g: g} +} + +// NewFloat64Gauge creates an OTel Float64Gauge. +func (b *Backend) NewFloat64Gauge(name, desc string, _ ...string) *Float64Gauge { + g, err := b.meter.Float64Gauge(name, metric.WithDescription(desc)) + if err != nil { + panic(fmt.Sprintf("otel: create float64 gauge %q: %v", name, err)) + } + return &Float64Gauge{g: g} +} + +// NewHistogram creates an OTel Float64Histogram with explicit bucket boundaries. +func (b *Backend) NewHistogram(name, desc string, buckets []float64, _ ...string) *Histogram { + h, err := b.meter.Float64Histogram(name, + metric.WithDescription(desc), + metric.WithExplicitBucketBoundaries(buckets...), + ) + if err != nil { + panic(fmt.Sprintf("otel: create histogram %q: %v", name, err)) + } + return &Histogram{h: h} +} + +// labelsToAttrs converts a Labels map to OTel attribute key-value pairs. +func labelsToAttrs(labels map[string]string) []attribute.KeyValue { + if len(labels) == 0 { + return nil + } + attrs := make([]attribute.KeyValue, 0, len(labels)) + for k, v := range labels { + attrs = append(attrs, attribute.String(k, v)) + } + return attrs +} + +// Counter wraps an OTel Int64Counter. +type Counter struct { + c metric.Int64Counter +} + +// Add increments the counter by value. +func (c *Counter) Add(ctx context.Context, value int64, labels map[string]string) { + c.c.Add(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) +} + +// UpDownCounter wraps an OTel Int64UpDownCounter. +type UpDownCounter struct { + c metric.Int64UpDownCounter +} + +// Add adjusts the up-down counter by value. 
+func (u *UpDownCounter) Add(ctx context.Context, value int64, labels map[string]string) { + u.c.Add(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) +} + +// Int64Gauge wraps an OTel Int64Gauge. +type Int64Gauge struct { + g metric.Int64Gauge +} + +// Record sets the gauge to value. +func (g *Int64Gauge) Record(ctx context.Context, value int64, labels map[string]string) { + g.g.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) +} + +// Float64Gauge wraps an OTel Float64Gauge. +type Float64Gauge struct { + g metric.Float64Gauge +} + +// Record sets the gauge to value. +func (g *Float64Gauge) Record(ctx context.Context, value float64, labels map[string]string) { + g.g.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) +} + +// Histogram wraps an OTel Float64Histogram. +type Histogram struct { + h metric.Float64Histogram +} + +// Record observes value in the histogram. +func (h *Histogram) Record(ctx context.Context, value float64, labels map[string]string) { + h.h.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) +} diff --git a/internal/observability/otel/backend_test.go b/internal/observability/otel/backend_test.go new file mode 100644 index 0000000..e527678 --- /dev/null +++ b/internal/observability/otel/backend_test.go @@ -0,0 +1,141 @@ +package otel_test + +import ( + "context" + "testing" + "time" + + obsotel "github.com/fosrl/gerbil/internal/observability/otel" +) + +const ( + defaultGRPCEndpoint = "localhost:4317" + defaultServiceName = "gerbil-test" +) + +func newInMemoryBackend(t *testing.T) *obsotel.Backend { + t.Helper() + // Use a very short export interval; an in-process collector (noop exporter) + // is used by pointing to a non-existent endpoint with insecure mode. + // The backend itself should initialise without error since connection is lazy. 
+ b, err := obsotel.New(obsotel.Config{ + Protocol: "grpc", + Endpoint: defaultGRPCEndpoint, + Insecure: true, + ExportInterval: 100 * time.Millisecond, + ServiceName: defaultServiceName, + ServiceVersion: "0.0.1", + }) + if err != nil { + t.Fatalf("failed to create otel backend: %v", err) + } + return b +} + +func TestOtelBackendHTTPHandlerIsNil(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + if b.HTTPHandler() != nil { + t.Error("OTel backend HTTPHandler should return nil") + } +} + +func TestOtelBackendShutdown(t *testing.T) { + b := newInMemoryBackend(t) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + if err := b.Shutdown(ctx); err != nil { + // Shutdown with unreachable collector may fail to flush; that's acceptable. + // What matters is that Shutdown does not panic. + t.Logf("Shutdown returned (expected with no collector): %v", err) + } +} + +func TestOtelBackendCounter(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + + c := b.NewCounter("gerbil_test_counter_total", "test counter", "result") + // Should not panic + c.Add(context.Background(), 1, map[string]string{"result": "ok"}) + c.Add(context.Background(), 5, nil) +} + +func TestOtelBackendUpDownCounter(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + + u := b.NewUpDownCounter("gerbil_test_updown", "test updown", "state") + u.Add(context.Background(), 3, map[string]string{"state": "active"}) + u.Add(context.Background(), -1, map[string]string{"state": "active"}) +} + +func TestOtelBackendInt64Gauge(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + + g := b.NewInt64Gauge("gerbil_test_int_gauge", "test gauge") + g.Record(context.Background(), 42, nil) +} + +func TestOtelBackendFloat64Gauge(t *testing.T) { + b := newInMemoryBackend(t) + 
defer b.Shutdown(context.Background()) //nolint:errcheck + + g := b.NewFloat64Gauge("gerbil_test_float_gauge", "test float gauge") + g.Record(context.Background(), 3.14, nil) +} + +func TestOtelBackendHistogram(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + + h := b.NewHistogram("gerbil_test_duration_seconds", "test histogram", + []float64{0.1, 0.5, 1.0}, "method") + h.Record(context.Background(), 0.3, map[string]string{"method": "GET"}) +} + +func TestOtelBackendHTTPProtocol(t *testing.T) { + b, err := obsotel.New(obsotel.Config{ + Protocol: "http", + Endpoint: "localhost:4318", + Insecure: true, + ExportInterval: 100 * time.Millisecond, + ServiceName: defaultServiceName, + }) + if err != nil { + t.Fatalf("failed to create otel http backend: %v", err) + } + defer b.Shutdown(context.Background()) //nolint:errcheck + + if b.HTTPHandler() != nil { + t.Error("OTel HTTP backend should not expose a /metrics endpoint") + } +} + +func TestOtelBackendInvalidProtocol(t *testing.T) { + _, err := obsotel.New(obsotel.Config{ + Protocol: "tcp", + Endpoint: defaultGRPCEndpoint, + ExportInterval: 10 * time.Second, + }) + if err == nil { + t.Error("expected error for invalid protocol") + } +} + +func TestOtelBackendDeploymentEnvironment(t *testing.T) { + b, err := obsotel.New(obsotel.Config{ + Protocol: "grpc", + Endpoint: defaultGRPCEndpoint, + Insecure: true, + ExportInterval: 100 * time.Millisecond, + ServiceName: defaultServiceName, + ServiceVersion: "1.2.3", + DeploymentEnvironment: "staging", + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer b.Shutdown(context.Background()) //nolint:errcheck +} diff --git a/internal/observability/otel/exporter.go b/internal/observability/otel/exporter.go new file mode 100644 index 0000000..44fe1e2 --- /dev/null +++ b/internal/observability/otel/exporter.go @@ -0,0 +1,50 @@ +package otel + +import ( + "context" + "fmt" + + 
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" + "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" + sdkmetric "go.opentelemetry.io/otel/sdk/metric" +) + +// newExporter creates the appropriate OTLP exporter based on cfg.Protocol. +func newExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error) { + switch cfg.Protocol { + case "grpc", "": + return newGRPCExporter(ctx, cfg) + case "http": + return newHTTPExporter(ctx, cfg) + default: + return nil, fmt.Errorf("otel: unknown protocol %q (must be \"grpc\" or \"http\")", cfg.Protocol) + } +} + +func newGRPCExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error) { + opts := []otlpmetricgrpc.Option{ + otlpmetricgrpc.WithEndpoint(cfg.Endpoint), + } + if cfg.Insecure { + opts = append(opts, otlpmetricgrpc.WithInsecure()) + } + exp, err := otlpmetricgrpc.New(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("otlp grpc exporter: %w", err) + } + return exp, nil +} + +func newHTTPExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error) { + opts := []otlpmetrichttp.Option{ + otlpmetrichttp.WithEndpoint(cfg.Endpoint), + } + if cfg.Insecure { + opts = append(opts, otlpmetrichttp.WithInsecure()) + } + exp, err := otlpmetrichttp.New(ctx, opts...) + if err != nil { + return nil, fmt.Errorf("otlp http exporter: %w", err) + } + return exp, nil +} diff --git a/internal/observability/otel/resource.go b/internal/observability/otel/resource.go new file mode 100644 index 0000000..47a14ff --- /dev/null +++ b/internal/observability/otel/resource.go @@ -0,0 +1,25 @@ +package otel + +import ( + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/sdk/resource" + semconv "go.opentelemetry.io/otel/semconv/v1.40.0" +) + +// newResource builds an OTel resource for the Gerbil service. 
+func newResource(serviceName, serviceVersion, deploymentEnv string) (*resource.Resource, error) { + attrs := []attribute.KeyValue{ + semconv.ServiceName(serviceName), + } + if serviceVersion != "" { + attrs = append(attrs, semconv.ServiceVersion(serviceVersion)) + } + if deploymentEnv != "" { + attrs = append(attrs, semconv.DeploymentEnvironmentName(deploymentEnv)) + } + + return resource.Merge( + resource.Default(), + resource.NewWithAttributes(semconv.SchemaURL, attrs...), + ) +} diff --git a/internal/observability/prometheus/backend.go b/internal/observability/prometheus/backend.go new file mode 100644 index 0000000..f2744c1 --- /dev/null +++ b/internal/observability/prometheus/backend.go @@ -0,0 +1,185 @@ +// Package prometheus implements the native Prometheus metrics backend for Gerbil. +// +// This backend uses the Prometheus Go client directly; it does NOT depend on the +// OpenTelemetry SDK. A dedicated Prometheus registry is used so that default +// Go/process metrics are not unintentionally included unless the caller opts in. +package prometheus + +import ( + "context" + "net/http" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/collectors" + "github.com/prometheus/client_golang/prometheus/promhttp" +) + +// Config holds Prometheus-backend configuration. +type Config struct { + // Path is the HTTP endpoint path (e.g. "/metrics"). + Path string + + // IncludeGoMetrics controls whether the standard Go runtime and process + // collectors are registered on the dedicated registry. + // Defaults to true if not explicitly set. + IncludeGoMetrics *bool +} + +// Backend is the native Prometheus metrics backend. +// Metric instruments are created via the New* family of methods and stored +// in the backend-specific instrument types that implement the observability +// instrument interfaces. 
+type Backend struct { + cfg Config + registry *prometheus.Registry + handler http.Handler +} + +// New creates and initialises a Prometheus backend. +// +// cfg.Path sets the HTTP endpoint path (defaults to "/metrics" if empty). +// cfg.IncludeGoMetrics controls whether standard Go runtime and process metrics +// are included; defaults to true when nil. +// +// Returns an error if the registry cannot be created. +func New(cfg Config) (*Backend, error) { + if cfg.Path == "" { + cfg.Path = "/metrics" + } + + registry := prometheus.NewRegistry() + + // Include Go and process metrics by default. + includeGo := cfg.IncludeGoMetrics == nil || *cfg.IncludeGoMetrics + if includeGo { + registry.MustRegister( + collectors.NewGoCollector(), + collectors.NewProcessCollector(collectors.ProcessCollectorOpts{}), + ) + } + + handler := promhttp.HandlerFor(registry, promhttp.HandlerOpts{ + EnableOpenMetrics: false, + }) + + return &Backend{cfg: cfg, registry: registry, handler: handler}, nil +} + +// HTTPHandler returns the Prometheus /metrics HTTP handler. +func (b *Backend) HTTPHandler() http.Handler { + return b.handler +} + +// Shutdown is a no-op for the Prometheus backend. +// The registry does not maintain background goroutines. +func (b *Backend) Shutdown(_ context.Context) error { + _ = b + return nil +} + +// NewCounter creates a Prometheus CounterVec registered on the backend's registry. +func (b *Backend) NewCounter(name, desc string, labelNames ...string) *Counter { + vec := prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: name, + Help: desc, + }, labelNames) + b.registry.MustRegister(vec) + return &Counter{vec: vec} +} + +// NewUpDownCounter creates a Prometheus GaugeVec (Prometheus gauges are +// bidirectional) registered on the backend's registry. 
+func (b *Backend) NewUpDownCounter(name, desc string, labelNames ...string) *UpDownCounter { + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: name, + Help: desc, + }, labelNames) + b.registry.MustRegister(vec) + return &UpDownCounter{vec: vec} +} + +// NewInt64Gauge creates a Prometheus GaugeVec registered on the backend's registry. +func (b *Backend) NewInt64Gauge(name, desc string, labelNames ...string) *Int64Gauge { + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: name, + Help: desc, + }, labelNames) + b.registry.MustRegister(vec) + return &Int64Gauge{vec: vec} +} + +// NewFloat64Gauge creates a Prometheus GaugeVec registered on the backend's registry. +func (b *Backend) NewFloat64Gauge(name, desc string, labelNames ...string) *Float64Gauge { + vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ + Name: name, + Help: desc, + }, labelNames) + b.registry.MustRegister(vec) + return &Float64Gauge{vec: vec} +} + +// NewHistogram creates a Prometheus HistogramVec registered on the backend's registry. +func (b *Backend) NewHistogram(name, desc string, buckets []float64, labelNames ...string) *Histogram { + vec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: name, + Help: desc, + Buckets: buckets, + }, labelNames) + b.registry.MustRegister(vec) + return &Histogram{vec: vec} +} + +// Counter is a native Prometheus counter instrument. +type Counter struct { + vec *prometheus.CounterVec +} + +// Add increments the counter by value for the given labels. +// +// value must be non-negative. Negative values are ignored. +func (c *Counter) Add(_ context.Context, value int64, labels map[string]string) { + if value < 0 { + return + } + c.vec.With(prometheus.Labels(labels)).Add(float64(value)) +} + +// UpDownCounter is a native Prometheus gauge used as a bidirectional counter. +type UpDownCounter struct { + vec *prometheus.GaugeVec +} + +// Add adjusts the gauge by value for the given labels. 
+func (u *UpDownCounter) Add(_ context.Context, value int64, labels map[string]string) { + u.vec.With(prometheus.Labels(labels)).Add(float64(value)) +} + +// Int64Gauge is a native Prometheus gauge recording integer snapshot values. +type Int64Gauge struct { + vec *prometheus.GaugeVec +} + +// Record sets the gauge to value for the given labels. +func (g *Int64Gauge) Record(_ context.Context, value int64, labels map[string]string) { + g.vec.With(prometheus.Labels(labels)).Set(float64(value)) +} + +// Float64Gauge is a native Prometheus gauge recording float snapshot values. +type Float64Gauge struct { + vec *prometheus.GaugeVec +} + +// Record sets the gauge to value for the given labels. +func (g *Float64Gauge) Record(_ context.Context, value float64, labels map[string]string) { + g.vec.With(prometheus.Labels(labels)).Set(value) +} + +// Histogram is a native Prometheus histogram instrument. +type Histogram struct { + vec *prometheus.HistogramVec +} + +// Record observes value for the given labels. 
+func (h *Histogram) Record(_ context.Context, value float64, labels map[string]string) { + h.vec.With(prometheus.Labels(labels)).Observe(value) +} diff --git a/internal/observability/prometheus/backend_test.go b/internal/observability/prometheus/backend_test.go new file mode 100644 index 0000000..d60821f --- /dev/null +++ b/internal/observability/prometheus/backend_test.go @@ -0,0 +1,173 @@ +package prometheus_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + obsprom "github.com/fosrl/gerbil/internal/observability/prometheus" +) + +func newTestBackend(t *testing.T) *obsprom.Backend { + t.Helper() + b, err := obsprom.New(obsprom.Config{Path: "/metrics"}) + if err != nil { + t.Fatalf("failed to create prometheus backend: %v", err) + } + return b +} + +func TestPrometheusBackendHTTPHandler(t *testing.T) { + b := newTestBackend(t) + if b.HTTPHandler() == nil { + t.Error("HTTPHandler should not be nil") + } +} + +func TestPrometheusBackendShutdown(t *testing.T) { + b := newTestBackend(t) + if err := b.Shutdown(context.Background()); err != nil { + t.Errorf("Shutdown returned error: %v", err) + } +} + +func TestPrometheusBackendCounter(t *testing.T) { + b := newTestBackend(t) + c := b.NewCounter("test_counter_total", "A test counter", "result") + c.Add(context.Background(), 3, map[string]string{"result": "ok"}) + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `test_counter_total{result="ok"} 3`) +} + +func TestPrometheusBackendUpDownCounter(t *testing.T) { + b := newTestBackend(t) + u := b.NewUpDownCounter("test_gauge_total", "A test up-down counter", "state") + u.Add(context.Background(), 5, map[string]string{"state": "active"}) + u.Add(context.Background(), -2, map[string]string{"state": "active"}) + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `test_gauge_total{state="active"} 3`) +} + +func TestPrometheusBackendInt64Gauge(t *testing.T) { + b := newTestBackend(t) + g := 
b.NewInt64Gauge("test_int_gauge", "An integer gauge", "ifname") + g.Record(context.Background(), 42, map[string]string{"ifname": "wg0"}) + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `test_int_gauge{ifname="wg0"} 42`) +} + +func TestPrometheusBackendFloat64Gauge(t *testing.T) { + b := newTestBackend(t) + g := b.NewFloat64Gauge("test_float_gauge", "A float gauge", "cert") + g.Record(context.Background(), 7.5, map[string]string{"cert": "example.com"}) + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `test_float_gauge{cert="example.com"} 7.5`) +} + +func TestPrometheusBackendHistogram(t *testing.T) { + b := newTestBackend(t) + buckets := []float64{0.1, 0.5, 1.0, 5.0} + h := b.NewHistogram("test_duration_seconds", "A test histogram", buckets, "method") + h.Record(context.Background(), 0.3, map[string]string{"method": "GET"}) + + body := scrapeMetrics(t, b) + if !strings.Contains(body, "test_duration_seconds") { + t.Errorf("expected histogram metric in output, body:\n%s", body) + } +} + +func TestPrometheusBackendMultipleLabels(t *testing.T) { + b := newTestBackend(t) + c := b.NewCounter("multi_label_total", "Multi-label counter", "method", "route", "status_code") + c.Add(context.Background(), 1, map[string]string{ + "method": "POST", + "route": "/api/peers", + "status_code": "200", + }) + + body := scrapeMetrics(t, b) + if !strings.Contains(body, "multi_label_total") { + t.Errorf("expected multi_label_total in output, body:\n%s", body) + } +} + +func TestPrometheusBackendGoMetrics(t *testing.T) { + b := newTestBackend(t) + body := scrapeMetrics(t, b) + // Default backend includes Go runtime metrics. 
+ if !strings.Contains(body, "go_goroutines") { + t.Error("expected go_goroutines in default backend output") + } +} + +func TestPrometheusBackendNoGoMetrics(t *testing.T) { + f := false + b, err := obsprom.New(obsprom.Config{IncludeGoMetrics: &f}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + body := scrapeMetrics(t, b) + if strings.Contains(body, "go_goroutines") { + t.Error("expected no go_goroutines when IncludeGoMetrics=false") + } +} + +func TestPrometheusBackendNilLabels(t *testing.T) { + // Adding with nil labels should not panic (treated as empty map). + b := newTestBackend(t) + c := b.NewCounter("nil_labels_total", "counter with no labels") + // nil labels with no label names declared should be safe + c.Add(context.Background(), 1, nil) +} + +func TestPrometheusBackendConcurrentAdd(t *testing.T) { + b := newTestBackend(t) + c := b.NewCounter("concurrent_total", "concurrent counter", "worker") + + done := make(chan struct{}) + for i := 0; i < 10; i++ { + go func(_ int) { + for j := 0; j < 100; j++ { + c.Add(context.Background(), 1, map[string]string{"worker": "w"}) + } + done <- struct{}{} + }(i) + } + for i := 0; i < 10; i++ { + <-done + } + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `concurrent_total{worker="w"} 1000`) +} + +// --- helpers --- + +func scrapeMetrics(t *testing.T, b *obsprom.Backend) string { + t.Helper() + req := httptest.NewRequest(http.MethodGet, "/metrics", http.NoBody) + rr := httptest.NewRecorder() + b.HTTPHandler().ServeHTTP(rr, req) + if rr.Code != http.StatusOK { + t.Fatalf("metrics handler returned %d", rr.Code) + } + body, err := io.ReadAll(rr.Body) + if err != nil { + t.Fatalf("failed to read response body: %v", err) + } + return string(body) +} + +func assertMetricPresent(t *testing.T, body, expected string) { + t.Helper() + if !strings.Contains(body, expected) { + t.Errorf("expected %q in metrics output\nbody:\n%s", expected, body) + } +} From 4357ddf64b9b5c32409548b547c5a7af26e36429 Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 15:57:53 +0200 Subject: [PATCH 11/25] Integrate metrics instrumentation across core services --- main.go | 220 ++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 207 insertions(+), 13 deletions(-) diff --git a/main.go b/main.go index 695de55..62bfe7c 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "syscall" "time" + "github.com/fosrl/gerbil/internal/metrics" "github.com/fosrl/gerbil/logger" "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" @@ -101,6 +102,35 @@ type UpdateDestinationsRequest struct { Destinations []relay.PeerDestination `json:"destinations"` } +// httpMetricsMiddleware wraps HTTP handlers with metrics tracking +func httpMetricsMiddleware(endpoint string, handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + startTime := time.Now() + + // Create a response writer wrapper to capture status code + ww := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK} + + // Call the actual handler + handler(ww, r) + + // Record metrics + duration := time.Since(startTime).Seconds() + metrics.RecordHTTPRequest(endpoint, r.Method, fmt.Sprintf("%d", ww.statusCode)) + metrics.RecordHTTPRequestDuration(endpoint, r.Method, duration) + } +} + +// responseWriterWrapper wraps http.ResponseWriter to capture status code +type responseWriterWrapper struct { + http.ResponseWriter + statusCode int +} + +func (w *responseWriterWrapper) WriteHeader(statusCode int) { + w.statusCode = statusCode + w.ResponseWriter.WriteHeader(statusCode) +} + func parseLogLevel(level string) logger.LogLevel { switch strings.ToUpper(level) { case "DEBUG": @@ -136,6 +166,15 @@ func main() { localOverridesStr string trustedUpstreamsStr string proxyProtocol bool + + // Metrics configuration variables (set from env, then overridden by CLI flags) + metricsEnabled bool + metricsBackend string + metricsPath string 
+ otelMetricsProtocol string + otelMetricsEndpoint string + otelMetricsInsecure bool + otelMetricsExportInterval time.Duration ) interfaceName = os.Getenv("INTERFACE") @@ -157,6 +196,40 @@ func main() { doTrafficShapingStr := os.Getenv("DO_TRAFFIC_SHAPING") bandwidthLimitStr := os.Getenv("BANDWIDTH_LIMIT") + // Read metrics env vars (defaults applied by DefaultMetricsConfig; these override defaults). + metricsEnabled = true // default + if v := os.Getenv("METRICS_ENABLED"); v != "" { + metricsEnabled = strings.ToLower(v) == "true" + } + metricsBackend = "prometheus" // default + if v := os.Getenv("METRICS_BACKEND"); v != "" { + metricsBackend = v + } + metricsPath = "/metrics" // default + if v := os.Getenv("METRICS_PATH"); v != "" { + metricsPath = v + } + otelMetricsProtocol = "grpc" // default + if v := os.Getenv("OTEL_METRICS_PROTOCOL"); v != "" { + otelMetricsProtocol = v + } + otelMetricsEndpoint = "localhost:4317" // default + if v := os.Getenv("OTEL_METRICS_ENDPOINT"); v != "" { + otelMetricsEndpoint = v + } + otelMetricsInsecure = true // default + if v := os.Getenv("OTEL_METRICS_INSECURE"); v != "" { + otelMetricsInsecure = strings.ToLower(v) == "true" + } + otelMetricsExportInterval = 60 * time.Second // default + if v := os.Getenv("OTEL_METRICS_EXPORT_INTERVAL"); v != "" { + if d, err2 := time.ParseDuration(v); err2 == nil { + otelMetricsExportInterval = d + } else { + log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2) + } + } + if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } @@ -241,6 +314,15 @@ func main() { flag.StringVar(&bandwidthLimit, "bandwidth-limit", "50mbit", "Bandwidth limit per peer for traffic shaping (e.g. 50mbit, 1gbit)") } + // Metrics CLI flags – always registered so that CLI overrides env/defaults. 
+ flag.BoolVar(&metricsEnabled, "metrics-enabled", metricsEnabled, "Enable metrics collection (default: true)") + flag.StringVar(&metricsBackend, "metrics-backend", metricsBackend, "Metrics backend: prometheus, otel, or none") + flag.StringVar(&metricsPath, "metrics-path", metricsPath, "HTTP path for Prometheus /metrics endpoint") + flag.StringVar(&otelMetricsProtocol, "otel-metrics-protocol", otelMetricsProtocol, "OTLP transport protocol: grpc or http") + flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. localhost:4317)") + flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection") + flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes") + flag.Parse() // Derive IFB device name from the WireGuard interface name (Linux limit: 15 chars) @@ -252,6 +334,38 @@ func main() { logger.Init() logger.GetLogger().SetLevel(parseLogLevel(logLevel)) + // Initialize metrics with the selected backend. + // Config precedence: CLI flags > env vars > defaults (already applied above). 
+ metricsHandler, err := metrics.Initialize(metrics.Config{ + Enabled: metricsEnabled, + Backend: metricsBackend, + Prometheus: metrics.PrometheusConfig{ + Path: metricsPath, + }, + OTel: metrics.OTelConfig{ + Protocol: otelMetricsProtocol, + Endpoint: otelMetricsEndpoint, + Insecure: otelMetricsInsecure, + ExportInterval: otelMetricsExportInterval, + }, + ServiceName: "gerbil", + ServiceVersion: "1.0.0", + DeploymentEnvironment: os.Getenv("DEPLOYMENT_ENVIRONMENT"), + }) + if err != nil { + logger.Fatal("Failed to initialize metrics: %v", err) + } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := metrics.Shutdown(shutdownCtx); err != nil { + logger.Error("Failed to shutdown metrics: %v", err) + } + }() + + // Record restart metric + metrics.RecordRestart() + // Base context for the application; cancel on SIGINT/SIGTERM ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -420,18 +534,27 @@ func main() { logger.Fatal("Failed to start proxy: %v", err) } - // Set up HTTP server - http.HandleFunc("/peer", handlePeer) - http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) - http.HandleFunc("/update-destinations", handleUpdateDestinations) - http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) - http.HandleFunc("/healthz", handleHealthz) + // Set up HTTP server with metrics middleware + http.HandleFunc("/peer", httpMetricsMiddleware("peer", handlePeer)) + http.HandleFunc("/update-proxy-mapping", httpMetricsMiddleware("update_proxy_mapping", handleUpdateProxyMapping)) + http.HandleFunc("/update-destinations", httpMetricsMiddleware("update_destinations", handleUpdateDestinations)) + http.HandleFunc("/update-local-snis", httpMetricsMiddleware("update_local_snis", handleUpdateLocalSNIs)) + http.HandleFunc("/healthz", httpMetricsMiddleware("healthz", handleHealthz)) + + // Register metrics endpoint only for Prometheus backend. 
+ // OTel backend pushes to a collector; no /metrics endpoint needed. + if metricsHandler != nil { + http.Handle(metricsPath, metricsHandler) + logger.Info("Metrics endpoint enabled at %s", metricsPath) + } + logger.Info("Starting HTTP server on %s", listenAddr) // HTTP server with graceful shutdown on context cancel server := &http.Server{ - Addr: listenAddr, - Handler: nil, + Addr: listenAddr, + Handler: nil, + ReadHeaderTimeout: 3 * time.Second, } group.Go(func() error { // http.ErrServerClosed is returned on graceful shutdown; not an error for us @@ -466,26 +589,35 @@ func main() { func loadRemoteConfig(url string, key wgtypes.Key, reachableAt string) (WgConfig, error) { var body *bytes.Buffer if reachableAt == "" { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s"}`, key.PublicKey().String()))) + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q}`, key.PublicKey().String()))) } else { - body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": "%s", "reachableAt": "%s"}`, key.PublicKey().String(), reachableAt))) + body = bytes.NewBuffer([]byte(fmt.Sprintf(`{"publicKey": %q, "reachableAt": %q}`, key.PublicKey().String(), reachableAt))) } resp, err := http.Post(url, "application/json", body) if err != nil { // print the error logger.Error("Error fetching remote config %s: %v", url, err) + // Record remote config fetch error + metrics.RecordRemoteConfigFetch("error") return WgConfig{}, err } defer resp.Body.Close() data, err := io.ReadAll(resp.Body) if err != nil { + metrics.RecordRemoteConfigFetch("error") return WgConfig{}, err } var config WgConfig err = json.Unmarshal(data, &config) + if err != nil { + metrics.RecordRemoteConfigFetch("error") + return config, err + } + // Record successful remote config fetch + metrics.RecordRemoteConfigFetch("success") return config, err } @@ -593,6 +725,10 @@ func ensureWireguardInterface(wgconfig WgConfig) error { logger.Info("WireGuard interface %s created and configured", interfaceName) + // Record 
interface state metric + hostname, _ := os.Hostname() + metrics.RecordInterfaceUp(interfaceName, hostname, true) + return nil } @@ -890,15 +1026,22 @@ func handleAddPeer(w http.ResponseWriter, r *http.Request) { var peer Peer if err := json.NewDecoder(r.Body).Decode(&peer); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) + // Record peer add error + metrics.RecordPeerOperation("add", "error") return } err := addPeer(peer) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + // Record peer add error + metrics.RecordPeerOperation("add", "error") return } + // Record peer add success + metrics.RecordPeerOperation("add", "success") + // Notify if notifyURL is set go notifyPeerChange("add", peer.PublicKey) @@ -971,6 +1114,10 @@ func addPeerInternal(peer Peer) error { logger.Info("Peer %s added successfully", peer.PublicKey) + // Record metrics + metrics.RecordPeersTotal(interfaceName, 1) + metrics.RecordAllowedIPsCount(interfaceName, peer.PublicKey, int64(len(peer.AllowedIPs))) + return nil } @@ -978,15 +1125,22 @@ func handleRemovePeer(w http.ResponseWriter, r *http.Request) { publicKey := r.URL.Query().Get("public_key") if publicKey == "" { http.Error(w, "Missing public_key query parameter", http.StatusBadRequest) + // Record peer remove error + metrics.RecordPeerOperation("remove", "error") return } err := removePeer(publicKey) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) + // Record peer remove error + metrics.RecordPeerOperation("remove", "error") return } + // Record peer remove success + metrics.RecordPeerOperation("remove", "success") + // Notify if notifyURL is set go notifyPeerChange("remove", publicKey) @@ -1052,6 +1206,10 @@ func removePeerInternal(publicKey string) error { logger.Info("Peer %s removed successfully", publicKey) + // Record metrics + metrics.RecordPeersTotal(interfaceName, -1) + metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs))) + return nil } @@ 
-1086,6 +1244,8 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + // Record error + metrics.RecordProxyMappingUpdateRequest("error") return } @@ -1096,6 +1256,9 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { update.OldDestination.DestinationIP, update.OldDestination.DestinationPort, update.NewDestination.DestinationIP, update.NewDestination.DestinationPort) + // Record success + metrics.RecordProxyMappingUpdateRequest("success") + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "status": "Proxy mappings updated successfully", @@ -1156,6 +1319,8 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) + // Record error + metrics.RecordDestinationsUpdateRequest("error") return } @@ -1164,6 +1329,9 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { logger.Info("Updated proxy mapping for %s:%d with %d destinations", request.SourceIP, request.SourcePort, len(request.Destinations)) + // Record success + metrics.RecordDestinationsUpdateRequest("success") + w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(map[string]interface{}{ "status": "Destinations updated successfully", @@ -1225,7 +1393,7 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { return nil, fmt.Errorf("failed to get device: %v", err) } - peerBandwidths := []PeerBandwidth{} + var peerBandwidths []PeerBandwidth now := time.Now() mu.Lock() @@ -1266,6 +1434,14 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { bytesInMB := bytesInDiff / (1024 * 1024) bytesOutMB := bytesOutDiff / (1024 * 1024) + // Record metrics (in bytes) + if bytesInDiff > 0 { + 
metrics.RecordBytesReceived(interfaceName, publicKey, int64(bytesInDiff)) + } + if bytesOutDiff > 0 { + metrics.RecordBytesTransmitted(interfaceName, publicKey, int64(bytesOutDiff)) + } + peerBandwidths = append(peerBandwidths, PeerBandwidth{ PublicKey: publicKey, BytesIn: bytesInMB, @@ -1305,24 +1481,31 @@ func calculatePeerBandwidth() ([]PeerBandwidth, error) { func reportPeerBandwidth(apiURL string) error { bandwidths, err := calculatePeerBandwidth() if err != nil { + // Record bandwidth report error + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to calculate peer bandwidth: %v", err) } jsonData, err := json.Marshal(bandwidths) if err != nil { + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to marshal bandwidth data: %v", err) } resp, err := http.Post(apiURL, "application/json", bytes.NewBuffer(jsonData)) if err != nil { + metrics.RecordBandwidthReport("error") return fmt.Errorf("failed to send bandwidth data: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { + metrics.RecordBandwidthReport("error") return fmt.Errorf("API returned non-OK status: %s", resp.Status) } + // Record successful bandwidth report + metrics.RecordBandwidthReport("success") return nil } @@ -1356,14 +1539,25 @@ func monitorMemory(limit uint64) { for { runtime.ReadMemStats(&m) if m.Alloc > limit { + // Determine severity based on how much over the limit + severity := "warning" + if m.Alloc > limit*2 { + severity = "critical" + } + fmt.Printf("Memory spike detected (%d bytes). 
Dumping profile...\n", m.Alloc) + // Record memory spike metric + metrics.RecordMemorySpike(severity) + f, err := os.Create(fmt.Sprintf("/var/config/heap/heap-spike-%d.pprof", time.Now().Unix())) if err != nil { log.Println("could not create profile:", err) } else { pprof.WriteHeapProfile(f) f.Close() + // Record heap profile written metric + metrics.RecordHeapProfileWritten() } // Wait a while before checking again to avoid spamming profiles @@ -1522,7 +1716,7 @@ func setupPeerBandwidthLimit(peerIP string) error { if output, err := cmd.CombinedOutput(); err != nil { logger.Warn("Failed to add egress filter for peer IP %s: %v, output: %s", ip, err, string(output)) } - + // Set up ingress shaping on the IFB device (peer's upload / ingress on wg0). // All wg0 ingress is redirected to ifb0 by ensureIFBDevice; we add a per-peer // class + src filter here so each peer gets its own independent rate limit. @@ -1626,7 +1820,7 @@ func removePeerBandwidthLimit(peerIP string) error { logger.Warn("Failed to remove egress class for peer IP %s: %v, output: %s", ip, err, string(output)) } } - + // Remove the ingress class and filters on the IFB device ifbClassID := fmt.Sprintf("2:%s", lastOctet) From e47a57cb4f62cf956fc2ceb60c0b3c78ed74e316 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 18:15:41 +0200 Subject: [PATCH 12/25] Enhance metrics tracking in SNIProxy connection handling --- proxy/proxy.go | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index f29878e..d53cdad 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -16,6 +16,7 @@ import ( "sync" "time" + "github.com/fosrl/gerbil/internal/metrics" "github.com/fosrl/gerbil/logger" "github.com/patrickmn/go-cache" ) @@ -487,6 +488,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { defer p.wg.Done() defer clientConn.Close() + metrics.RecordSNIConnection("accepted") + logger.Debug("Accepted connection 
from %s", clientConn.RemoteAddr()) // Check for PROXY protocol from trusted upstream @@ -497,10 +500,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { var err error proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn) if err != nil { + metrics.RecordSNIProxyProtocolParseError() logger.Debug("Failed to parse PROXY protocol: %v", err) return } if proxyInfo != nil { + metrics.RecordSNITrustedProxyEvent("proxy_protocol_parsed") logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort) } else { @@ -517,11 +522,13 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } // Extract SNI hostname + clientHelloStart := time.Now() hostname, clientReader, err := p.extractSNI(actualClientConn) if err != nil { logger.Debug("SNI extraction failed: %v", err) return } + metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds()) if hostname == "" { log.Println("No SNI hostname found") @@ -569,6 +576,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { defer targetConn.Close() logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) + metrics.RecordActiveProxyConnection(hostname, 1) + defer metrics.RecordActiveProxyConnection(hostname, -1) // Send PROXY protocol header if enabled if p.proxyProtocol { @@ -618,7 +627,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { }() // Start bidirectional data transfer - p.pipe(actualClientConn, targetConn, clientReader) + p.pipe(hostname, actualClientConn, targetConn, clientReader) } // getRoute retrieves routing information for a hostname @@ -626,6 +635,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check local overrides first if _, isOverride := p.localOverrides[hostname]; isOverride { logger.Debug("Local override matched for hostname: %s", hostname) + metrics.RecordProxyRouteLookup("local_override", 
hostname) return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -638,6 +648,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { _, isLocal := p.localSNIs[hostname] p.localSNIsLock.RUnlock() if isLocal { + metrics.RecordProxyRouteLookup("local", hostname) return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -648,13 +659,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check cache first if cached, found := p.cache.Get(hostname); found { if cached == nil { + metrics.RecordProxyRouteLookup("cached_not_found", hostname) return nil, nil // Cached negative result } logger.Debug("Cache hit for hostname: %s", hostname) + metrics.RecordProxyRouteLookup("cache_hit", hostname) return cached.(*RouteRecord), nil } logger.Debug("Cache miss for hostname: %s, querying API", hostname) + metrics.RecordProxyRouteLookup("cache_miss", hostname) // Query API with timeout ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) @@ -682,22 +696,28 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { req.Header.Set("Content-Type", "application/json") // Make HTTP request + apiStart := time.Now() client := &http.Client{Timeout: 5 * time.Second} resp, err := client.Do(req) if err != nil { + metrics.RecordSNIRouteAPIRequest("error") return nil, fmt.Errorf("API request failed: %w", err) } defer resp.Body.Close() + metrics.RecordSNIRouteAPILatency(time.Since(apiStart).Seconds()) if resp.StatusCode == http.StatusNotFound { + metrics.RecordSNIRouteAPIRequest("not_found") // Cache negative result for shorter time (1 minute) p.cache.Set(hostname, nil, 1*time.Minute) return nil, nil } if resp.StatusCode != http.StatusOK { + metrics.RecordSNIRouteAPIRequest("error") return nil, fmt.Errorf("API returned status %d", resp.StatusCode) } + metrics.RecordSNIRouteAPIRequest("success") // Parse response var apiResponse RouteAPIResponse @@ -754,7 +774,7 @@ func (p 
*SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s } // pipe handles bidirectional data transfer between connections -func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { +func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, clientReader io.Reader) { var wg sync.WaitGroup wg.Add(2) @@ -775,7 +795,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) // Use a large buffer for better performance buf := make([]byte, 32*1024) - _, err := io.CopyBuffer(targetConn, clientReader, buf) + bytesCopied, err := io.CopyBuffer(targetConn, clientReader, buf) + metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy client->target error: %v", err) } @@ -788,7 +809,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) // Use a large buffer for better performance buf := make([]byte, 32*1024) - _, err := io.CopyBuffer(clientConn, targetConn, buf) + bytesCopied, err := io.CopyBuffer(clientConn, targetConn, buf) + metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy target->client error: %v", err) } From 652d9c5c6833be8933557746e0f7cdb021decaf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 18:15:58 +0200 Subject: [PATCH 13/25] Add metrics tracking for UDP packet handling and session management --- relay/relay.go | 115 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 101 insertions(+), 14 deletions(-) diff --git a/relay/relay.go b/relay/relay.go index 0ab5930..c065c29 100644 --- a/relay/relay.go +++ b/relay/relay.go @@ -13,12 +13,15 @@ import ( "sync" "time" + "github.com/fosrl/gerbil/internal/metrics" "github.com/fosrl/gerbil/logger" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/curve25519" 
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +const relayIfname = "relay" + type EncryptedHolePunchMessage struct { EphemeralPublicKey string `json:"ephemeralPublicKey"` Nonce []byte `json:"nonce"` @@ -290,9 +293,13 @@ func (s *UDPProxyServer) packetWorker() { for packet := range s.packetChan { // Determine packet type by inspecting the first byte. if packet.n > 0 && packet.data[0] >= 1 && packet.data[0] <= 4 { + metrics.RecordUDPPacket(relayIfname, "wireguard", "in") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(packet.n)) // Process as a WireGuard packet. s.handleWireGuardPacket(packet.data, packet.remoteAddr) } else { + metrics.RecordUDPPacket(relayIfname, "hole_punch", "in") + metrics.RecordUDPPacketSize(relayIfname, "hole_punch", float64(packet.n)) // Rate limit: allow at most 2 hole punch messages per IP:Port per second rateLimitKey := packet.remoteAddr.String() entryVal, _ := s.holePunchRateLimiter.LoadOrStore(rateLimitKey, &holePunchRateLimitEntry{ @@ -310,6 +317,7 @@ func (s *UDPProxyServer) packetWorker() { rlEntry.mu.Unlock() if !allowed { // logger.Debug("Rate limiting hole punch message from %s", rateLimitKey) + metrics.RecordHolePunchEvent(relayIfname, "rate_limited") bufferPool.Put(packet.data[:1500]) continue } @@ -318,6 +326,7 @@ func (s *UDPProxyServer) packetWorker() { var encMsg EncryptedHolePunchMessage if err := json.Unmarshal(packet.data, &encMsg); err != nil { logger.Error("Error unmarshaling encrypted message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -325,6 +334,7 @@ func (s *UDPProxyServer) packetWorker() { if encMsg.EphemeralPublicKey == "" { logger.Error("Received malformed message without ephemeral key") + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ 
-334,6 +344,7 @@ func (s *UDPProxyServer) packetWorker() { decryptedData, err := s.decryptMessage(encMsg) if err != nil { // logger.Error("Failed to decrypt message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -343,6 +354,7 @@ func (s *UDPProxyServer) packetWorker() { var msg HolePunchMessage if err := json.Unmarshal(decryptedData, &msg); err != nil { logger.Error("Error unmarshaling decrypted message: %v", err) + metrics.RecordHolePunchEvent(relayIfname, "error") // Return the buffer to the pool for reuse and continue with next packet bufferPool.Put(packet.data[:1500]) continue @@ -362,6 +374,7 @@ func (s *UDPProxyServer) packetWorker() { logger.Debug("Created endpoint from packet remoteAddr %s: IP=%s, Port=%d", packet.remoteAddr.String(), endpoint.IP, endpoint.Port) s.notifyServer(endpoint) s.clearSessionsForIP(endpoint.IP) // Clear sessions for this IP to allow re-establishment + metrics.RecordHolePunchEvent(relayIfname, "success") } // Return the buffer to the pool for reuse. 
bufferPool.Put(packet.data[:1500]) @@ -429,6 +442,8 @@ func (s *UDPProxyServer) fetchInitialMappings() error { mapping.LastUsed = time.Now() s.proxyMappings.Store(key, mapping) } + metrics.RecordProxyInitialMappings(relayIfname, int64(len(initialMappings.Mappings))) + metrics.RecordProxyMapping(relayIfname, int64(len(initialMappings.Mappings))) logger.Info("Loaded %d initial proxy mappings", len(initialMappings.Mappings)) return nil } @@ -544,7 +559,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward handshake initiation: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } case WireGuardMessageTypeHandshakeResponse: @@ -556,12 +575,17 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD sessionKey := fmt.Sprintf("%d:%d", receiverIndex, senderIndex) // Store the session information - s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: remoteAddr, LastSeen: time.Now(), - }) + } + if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + } // Forward the response to the original sender for _, dest := range proxyMapping.Destinations { @@ -580,7 +604,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Error("Failed to forward handshake response: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } case 
WireGuardMessageTypeTransportData: @@ -617,7 +645,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + return } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } else { // No known session, fall back to forwarding to all peers logger.Debug("No session found for receiver index %d, forwarding to all destinations", receiverIndex) @@ -640,7 +672,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Debug("Failed to forward transport data: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } } @@ -665,7 +701,11 @@ func (s *UDPProxyServer) handleWireGuardPacket(packet []byte, remoteAddr *net.UD _, err = conn.Write(packet) if err != nil { logger.Error("Failed to forward WireGuard packet: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(len(packet))) } } } @@ -683,6 +723,7 @@ func (s *UDPProxyServer) getOrCreateConnection(destAddr *net.UDPAddr, remoteAddr // Create new connection newConn, err := net.DialUDP("udp", nil, destAddr) if err != nil { + metrics.RecordProxyConnectionError(relayIfname, "dial_udp") return nil, fmt.Errorf("failed to create UDP connection: %v", err) } @@ -706,6 +747,8 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd logger.Debug("Error reading response from %s: %v", destAddr.String(), err) return } + 
metrics.RecordUDPPacket(relayIfname, "wireguard", "in") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n)) // Process the response to track sessions if it's a WireGuard packet if n > 0 && buffer[0] >= 1 && buffer[0] <= 4 { @@ -713,12 +756,17 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd if ok && buffer[0] == WireGuardMessageTypeHandshakeResponse { // Store the session mapping for the handshake response sessionKey := fmt.Sprintf("%d:%d", senderIndex, receiverIndex) - s.wgSessions.Store(sessionKey, &WireGuardSession{ + session := &WireGuardSession{ ReceiverIndex: receiverIndex, SenderIndex: senderIndex, DestAddr: destAddr, LastSeen: time.Now(), - }) + } + if _, loaded := s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + } logger.Debug("Stored session mapping: %s -> %s", sessionKey, destAddr.String()) } else if ok && buffer[0] == WireGuardMessageTypeTransportData { // Track communication pattern for session rebuilding (reverse direction) @@ -730,7 +778,11 @@ func (s *UDPProxyServer) handleResponses(conn *net.UDPConn, destAddr *net.UDPAdd _, err = s.conn.WriteToUDP(buffer[:n], remoteAddr) if err != nil { logger.Error("Failed to forward response: %v", err) + metrics.RecordProxyConnectionError(relayIfname, "write_udp") + continue } + metrics.RecordUDPPacket(relayIfname, "wireguard", "out") + metrics.RecordUDPPacketSize(relayIfname, "wireguard", float64(n)) } } @@ -741,15 +793,18 @@ func (s *UDPProxyServer) cleanupIdleConnections() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.connections.Range(func(key, value interface{}) bool { destConn := value.(*DestinationConn) if now.Sub(destConn.lastUsed) > 10*time.Minute { destConn.conn.Close() s.connections.Delete(key) + metrics.RecordProxyCleanupRemoved(relayIfname, "conn", 1) } return true }) + 
metrics.RecordProxyIdleCleanupDuration(relayIfname, "conn", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -764,16 +819,20 @@ func (s *UDPProxyServer) cleanupIdleSessions() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.wgSessions.Range(func(key, value interface{}) bool { session := value.(*WireGuardSession) // Use thread-safe method to read LastSeen if now.Sub(session.GetLastSeen()) > 15*time.Minute { s.wgSessions.Delete(key) + metrics.RecordSession(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "session", 1) logger.Debug("Removed idle session: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "session", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -787,16 +846,20 @@ func (s *UDPProxyServer) cleanupIdleProxyMappings() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.proxyMappings.Range(func(key, value interface{}) bool { mapping := value.(ProxyMapping) // Remove mappings that haven't been used in 30 minutes if now.Sub(mapping.LastUsed) > 30*time.Minute { s.proxyMappings.Delete(key) + metrics.RecordProxyMapping(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "proxy_mapping", 1) logger.Debug("Removed idle proxy mapping: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "proxy_mapping", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } @@ -839,6 +902,11 @@ func (s *UDPProxyServer) notifyServer(endpoint ClientEndpoint) { key := fmt.Sprintf("%s:%d", endpoint.IP, endpoint.Port) logger.Debug("About to store proxy mapping with key: %s (from endpoint IP=%s, Port=%d)", key, endpoint.IP, endpoint.Port) mapping.LastUsed = time.Now() + if _, existed := s.proxyMappings.Load(key); existed { + metrics.RecordProxyMappingUpdate(relayIfname) + } else { + metrics.RecordProxyMapping(relayIfname, 1) + } s.proxyMappings.Store(key, mapping) logger.Debug("Stored 
proxy mapping for %s with %d destinations (timestamp: %v)", key, len(mapping.Destinations), mapping.LastUsed) @@ -851,6 +919,11 @@ func (s *UDPProxyServer) UpdateProxyMapping(sourceIP string, sourcePort int, des Destinations: destinations, LastUsed: time.Now(), } + if _, existed := s.proxyMappings.Load(key); existed { + metrics.RecordProxyMappingUpdate(relayIfname) + } else { + metrics.RecordProxyMapping(relayIfname, 1) + } s.proxyMappings.Store(key, mapping) } @@ -917,6 +990,10 @@ func (s *UDPProxyServer) clearSessionsForIP(ip string) { for _, key := range keysToDelete { s.wgSessions.Delete(key) } + if len(keysToDelete) > 0 { + metrics.RecordSession(relayIfname, -int64(len(keysToDelete))) + metrics.RecordProxyCleanupRemoved(relayIfname, "session", int64(len(keysToDelete))) + } logger.Debug("Cleared %d sessions for WG IP: %s", len(keysToDelete), ip) } @@ -1077,7 +1154,9 @@ func (s *UDPProxyServer) trackCommunicationPattern(fromAddr, toAddr *net.UDPAddr pattern.LastFromDest = now } - s.commPatterns.Store(patternKey, pattern) + if _, loaded := s.commPatterns.LoadOrStore(patternKey, pattern); !loaded { + metrics.RecordCommPattern(relayIfname, 1) + } } } @@ -1095,16 +1174,20 @@ func (s *UDPProxyServer) tryRebuildSession(pattern *CommunicationPattern) { sessionKey := fmt.Sprintf("%d:%d", pattern.DestIndex, pattern.ClientIndex) // Check if we already have this session - if _, exists := s.wgSessions.Load(sessionKey); !exists { - s.wgSessions.Store(sessionKey, &WireGuardSession{ - ReceiverIndex: pattern.DestIndex, - SenderIndex: pattern.ClientIndex, - DestAddr: pattern.ToDestination, - LastSeen: time.Now(), - }) - logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", - sessionKey, pattern.ToDestination.String(), pattern.PacketCount) + session := &WireGuardSession{ + ReceiverIndex: pattern.DestIndex, + SenderIndex: pattern.ClientIndex, + DestAddr: pattern.ToDestination, + LastSeen: time.Now(), } + if _, loaded := 
s.wgSessions.LoadOrStore(sessionKey, session); loaded { + s.wgSessions.Store(sessionKey, session) + } else { + metrics.RecordSession(relayIfname, 1) + metrics.RecordSessionRebuilt(relayIfname) + } + logger.Info("Rebuilt WireGuard session from communication pattern: %s -> %s (packets: %d)", + sessionKey, pattern.ToDestination.String(), pattern.PacketCount) } } @@ -1139,6 +1222,7 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { for { select { case <-ticker.C: + cleanupStart := time.Now() now := time.Now() s.commPatterns.Range(func(key, value interface{}) bool { pattern := value.(*CommunicationPattern) @@ -1152,10 +1236,13 @@ func (s *UDPProxyServer) cleanupIdleCommunicationPatterns() { // Remove patterns that haven't had activity in 20 minutes if now.Sub(lastActivity) > 20*time.Minute { s.commPatterns.Delete(key) + metrics.RecordCommPattern(relayIfname, -1) + metrics.RecordProxyCleanupRemoved(relayIfname, "comm_pattern", 1) logger.Debug("Removed idle communication pattern: %s", key) } return true }) + metrics.RecordProxyIdleCleanupDuration(relayIfname, "comm_pattern", time.Since(cleanupStart).Seconds()) case <-s.ctx.Done(): return } From f07c83fde4314ee02ce51acab89c9e6ee77862ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Fri, 3 Apr 2026 18:41:40 +0200 Subject: [PATCH 14/25] Update Go version to 1.25.0 and add gRPC dependency --- go.mod | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/go.mod b/go.mod index ecae7a6..fa157a3 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/fosrl/gerbil -go 1.25 +go 1.25.0 require ( github.com/patrickmn/go-cache v2.1.0+incompatible @@ -15,7 +15,6 @@ require ( golang.org/x/crypto v0.48.0 golang.org/x/sync v0.19.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 - google.golang.org/grpc v1.79.3 ) require ( @@ -46,5 +45,6 @@ require ( golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect 
google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect + google.golang.org/grpc v1.79.3 // indirect google.golang.org/protobuf v1.36.11 // indirect ) From eedd813e2fed895ce3def66d8e08bad75bca68f4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Sat, 4 Apr 2026 01:19:47 +0200 Subject: [PATCH 15/25] Update Go version in GitHub Actions workflow --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2466b7d..82ce101 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: - name: Set up Go uses: actions/setup-go@4b73464bb391d4059bd26b0524d20df3927bd417 # v6.3.0 with: - go-version: 1.25 + go-version: 1.26 - name: Build go run: go build From 58415dee7e6979b3bff4a1cc422f81e52c54b7cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Sat, 4 Apr 2026 01:44:59 +0200 Subject: [PATCH 16/25] refactor: remove redundant HTTP client instantiation in getRoute method --- proxy/proxy.go | 1 - 1 file changed, 1 deletion(-) diff --git a/proxy/proxy.go b/proxy/proxy.go index da93c42..71cf4ed 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -717,7 +717,6 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Make HTTP request apiStart := time.Now() - client := &http.Client{Timeout: 5 * time.Second} // Make HTTP request using reusable client resp, err := p.httpClient.Do(req) if err != nil { From b642df3e1e00553a1aca67558dab65b85210987b Mon Sep 17 00:00:00 2001 From: Owen Date: Tue, 7 Apr 2026 11:33:29 -0400 Subject: [PATCH 17/25] Add CODEOWNERS --- .github/CODEOWNERS | 1 + 1 file changed, 1 insertion(+) create mode 100644 .github/CODEOWNERS diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..c5f1403 --- /dev/null +++ 
b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @oschwartz10612 @miloschwartz From 4a322a436bd0502231cbce1dbe26ba5ed4425018 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:16 +0200 Subject: [PATCH 18/25] chore(docs): add OTLP timeout docs and minor fixes --- .github/CODEOWNERS | 1 + README.md | 2 +- docs/observability.md | 12 ++++++++---- examples/otel-collector-config.yaml | 3 ++- 4 files changed, 12 insertions(+), 6 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c5f1403..7d8c330 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1 +1,2 @@ * @oschwartz10612 @miloschwartz +internal/observability/** @marcschaeferger diff --git a/README.md b/README.md index e403d88..6c324c1 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ make ### Binary -Make sure to have Go 1.23.1 installed. +Make sure to have Go 1.26 installed. ```bash make local diff --git a/docs/observability.md b/docs/observability.md index 13cb038..64184d5 100644 --- a/docs/observability.md +++ b/docs/observability.md @@ -83,6 +83,7 @@ type OTelConfig struct { Endpoint string // default: "localhost:4317" Insecure bool // default: true ExportInterval time.Duration // default: 60s + Timeout time.Duration // default: 10s } ``` @@ -97,6 +98,7 @@ type OTelConfig struct { | `OTEL_METRICS_ENDPOINT` | `localhost:4317` | OTLP collector address | | `OTEL_METRICS_INSECURE` | `true` | Disable TLS for OTLP | | `OTEL_METRICS_EXPORT_INTERVAL` | `60s` | Push interval (e.g. 
`10s`, `1m`) | +| `OTEL_METRICS_TIMEOUT` | `10s` | Timeout for OTLP exporter connection setup | | `DEPLOYMENT_ENVIRONMENT` | _(unset)_ | OTel deployment.environment attribute | ### CLI flags @@ -108,7 +110,8 @@ type OTelConfig struct { --otel-metrics-protocol string (default: grpc) --otel-metrics-endpoint string (default: localhost:4317) --otel-metrics-insecure bool (default: true) ---otel-metrics-export-interval duration (default: 1m0s) +--otel-metrics-export-interval duration (default: 60s) +--otel-metrics-timeout duration (default: 10s) ``` --- @@ -164,6 +167,7 @@ export OTEL_METRICS_PROTOCOL=grpc export OTEL_METRICS_ENDPOINT=otel-collector:4317 export OTEL_METRICS_INSECURE=true export OTEL_METRICS_EXPORT_INTERVAL=10s +export OTEL_METRICS_TIMEOUT=10s export DEPLOYMENT_ENVIRONMENT=production ``` @@ -176,6 +180,7 @@ export DEPLOYMENT_ENVIRONMENT=production --otel-metrics-endpoint=otel-collector:4317 \ --otel-metrics-insecure \ --otel-metrics-export-interval=10s \ + --otel-metrics-timeout=10s \ --config=/etc/gerbil/config.json ``` @@ -225,7 +230,6 @@ All metrics use the prefix `gerbil__`. | Metric | Type | Labels | |--------|------|--------| | `gerbil_proxy_mapping_active` | UpDownCounter | `ifname` | -| `gerbil_session_active` | UpDownCounter | `ifname` | | `gerbil_active_sessions` | UpDownCounter | `ifname` | | `gerbil_udp_packets_total` | Counter | `ifname`, `type`, `direction` | | `gerbil_hole_punch_events_total` | Counter | `ifname`, `result` | @@ -256,7 +260,7 @@ The `docker-compose.metrics.yml` provides a complete observability stack. 
**Prometheus mode:** ```bash -METRICS_BACKEND=prometheus docker-compose -f docker-compose.metrics.yml up -d +METRICS_BACKEND=prometheus docker-compose -f docker compose.metrics.yml up -d # Scrape at http://localhost:3003/metrics # Grafana at http://localhost:3000 (admin/admin) ``` @@ -265,5 +269,5 @@ METRICS_BACKEND=prometheus docker-compose -f docker-compose.metrics.yml up -d ```bash METRICS_BACKEND=otel OTEL_METRICS_ENDPOINT=otel-collector:4317 \ - docker-compose -f docker-compose.metrics.yml up -d + docker compose -f docker-compose.metrics.yml up -d ``` diff --git a/examples/otel-collector-config.yaml b/examples/otel-collector-config.yaml index 5c85356..acfa434 100644 --- a/examples/otel-collector-config.yaml +++ b/examples/otel-collector-config.yaml @@ -1,3 +1,4 @@ +file_format: '1.0' receivers: otlp: protocols: @@ -43,4 +44,4 @@ service: metrics: receivers: [otlp] processors: [batch, resource] - exporters: [prometheus, prometheusremotewrite, debug] \ No newline at end of file + exporters: [prometheus, prometheusremotewrite, debug] From f130a7cdb80ac50f1bf806f8febfa597cc4294db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:22 +0200 Subject: [PATCH 19/25] chore(deps): update OpenTelemetry and related modules --- go.mod | 24 ++++++++++++------------ go.sum | 52 ++++++++++++++++++++++++++-------------------------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/go.mod b/go.mod index ace40d4..bd2fe30 100644 --- a/go.mod +++ b/go.mod @@ -6,12 +6,12 @@ require ( github.com/patrickmn/go-cache v2.1.0+incompatible github.com/prometheus/client_golang v1.20.5 github.com/vishvananda/netlink v1.3.1 - go.opentelemetry.io/otel v1.42.0 - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 - go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 - go.opentelemetry.io/otel/metric v1.42.0 - go.opentelemetry.io/otel/sdk v1.42.0 - go.opentelemetry.io/otel/sdk/metric v1.42.0 + 
go.opentelemetry.io/otel v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 + go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 + go.opentelemetry.io/otel/metric v1.43.0 + go.opentelemetry.io/otel/sdk v1.43.0 + go.opentelemetry.io/otel/sdk/metric v1.43.0 golang.org/x/crypto v0.49.0 golang.org/x/sync v0.20.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 @@ -37,14 +37,14 @@ require ( github.com/prometheus/procfs v0.15.1 // indirect github.com/vishvananda/netns v0.0.5 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/otel/trace v1.42.0 // indirect - go.opentelemetry.io/proto/otlp v1.9.0 // indirect - golang.org/x/net v0.51.0 // indirect + go.opentelemetry.io/otel/trace v1.43.0 // indirect + go.opentelemetry.io/proto/otlp v1.10.0 // indirect + golang.org/x/net v0.52.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 // indirect - google.golang.org/grpc v1.79.3 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect + google.golang.org/grpc v1.80.0 // indirect google.golang.org/protobuf v1.36.11 // indirect ) diff --git a/go.sum b/go.sum index 20a033f..91c4b4f 100644 --- a/go.sum +++ b/go.sum @@ -55,26 +55,26 @@ github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zd github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod 
h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= -go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0 h1:MdKucPl/HbzckWWEisiNqMPhRrAOQX8r4jTuGr636gk= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.42.0/go.mod h1:RolT8tWtfHcjajEH5wFIZ4Dgh5jpPdFXYV9pTAk/qjc= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0 h1:H7O6RlGOMTizyl3R08Kn5pdM06bnH8oscSj7o11tmLA= -go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.42.0/go.mod h1:mBFWu/WOVDkWWsR7Tx7h6EpQB8wsv7P0Yrh0Pb7othc= -go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= -go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= -go.opentelemetry.io/otel/sdk v1.42.0 h1:LyC8+jqk6UJwdrI/8VydAq/hvkFKNHZVIWuslJXYsDo= -go.opentelemetry.io/otel/sdk v1.42.0/go.mod h1:rGHCAxd9DAph0joO4W6OPwxjNTYWghRWmkHuGbayMts= -go.opentelemetry.io/otel/sdk/metric v1.42.0 h1:D/1QR46Clz6ajyZ3G8SgNlTJKBdGp84q9RKCAZ3YGuA= -go.opentelemetry.io/otel/sdk/metric v1.42.0/go.mod h1:Ua6AAlDKdZ7tdvaQKfSmnFTdHx37+J4ba8MwVCYM5hc= -go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= -go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= -go.opentelemetry.io/proto/otlp v1.9.0 h1:l706jCMITVouPOqEnii2fIAuO3IVGBRPV5ICjceRb/A= -go.opentelemetry.io/proto/otlp v1.9.0/go.mod h1:xE+Cx5E/eEHw+ISFkwPLwCZefwVjY+pqKg1qcK03+/4= +go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= +go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0 h1:8UQVDcZxOJLtX6gxtDt3vY2WTgvZqMQRzjsqiIHQdkc= 
+go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc v1.43.0/go.mod h1:2lmweYCiHYpEjQ/lSJBYhj9jP1zvCvQW4BqL9dnT7FQ= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0 h1:w1K+pCJoPpQifuVpsKamUdn9U0zM3xUziVOqsGksUrY= +go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp v1.43.0/go.mod h1:HBy4BjzgVE8139ieRI75oXm3EcDN+6GhD88JT1Kjvxg= +go.opentelemetry.io/otel/metric v1.43.0 h1:d7638QeInOnuwOONPp4JAOGfbCEpYb+K6DVWvdxGzgM= +go.opentelemetry.io/otel/metric v1.43.0/go.mod h1:RDnPtIxvqlgO8GRW18W6Z/4P462ldprJtfxHxyKd2PY= +go.opentelemetry.io/otel/sdk v1.43.0 h1:pi5mE86i5rTeLXqoF/hhiBtUNcrAGHLKQdhg4h4V9Dg= +go.opentelemetry.io/otel/sdk v1.43.0/go.mod h1:P+IkVU3iWukmiit/Yf9AWvpyRDlUeBaRg6Y+C58QHzg= +go.opentelemetry.io/otel/sdk/metric v1.43.0 h1:S88dyqXjJkuBNLeMcVPRFXpRw2fuwdvfCGLEo89fDkw= +go.opentelemetry.io/otel/sdk/metric v1.43.0/go.mod h1:C/RJtwSEJ5hzTiUz5pXF1kILHStzb9zFlIEe85bhj6A= +go.opentelemetry.io/otel/trace v1.43.0 h1:BkNrHpup+4k4w+ZZ86CZoHHEkohws8AY+WTX09nk+3A= +go.opentelemetry.io/otel/trace v1.43.0/go.mod h1:/QJhyVBUUswCphDVxq+8mld+AvhXZLhe+8WVFxiFff0= +go.opentelemetry.io/proto/otlp v1.10.0 h1:IQRWgT5srOCYfiWnpqUYz9CVmbO8bFmKcwYxpuCSL2g= +go.opentelemetry.io/proto/otlp v1.10.0/go.mod h1:/CV4QoCR/S9yaPj8utp3lvQPoqMtxXdzn7ozvvozVqk= golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= -golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= -golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.2.0/go.mod 
h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -87,14 +87,14 @@ golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b/go.mod h1:tqur9LnfstdR9ep2LaJT4lFUl0EjlHtge+gAjmsHUG4= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57 h1:JLQynH/LBHfCTSbDWl+py8C+Rg/k1OVH3xfcaiANuF0= -google.golang.org/genproto/googleapis/api v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:kSJwQxqmFXeo79zOmbrALdflXQeAYcUbgS7PbpMknCY= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57 h1:mWPCjDEyshlQYzBpMNHaEof6UX1PmHcaUODUywQ0uac= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260209200024-4cfbd4190f57/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= -google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= -google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9 h1:VPWxll4HlMw1Vs/qXtN7BvhZqsS9cdAittCNvVENElA= +google.golang.org/genproto/googleapis/api v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:7QBABkRtR8z+TEnmXTqIqwJLlzrZKVfAUm7tY3yGv0M= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod 
h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= From bcb5cc4746b64e1d8969ef92a9076b898a32672c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:23 +0200 Subject: [PATCH 20/25] refactor(metrics): safe initialization and instrument factory --- internal/metrics/metrics.go | 546 +++++++++++++++++++++++++++---- internal/metrics/metrics_test.go | 12 +- 2 files changed, 482 insertions(+), 76 deletions(-) diff --git a/internal/metrics/metrics.go b/internal/metrics/metrics.go index 4a92b9f..e38aa90 100644 --- a/internal/metrics/metrics.go +++ b/internal/metrics/metrics.go @@ -7,7 +7,9 @@ package metrics import ( "context" + "fmt" "net/http" + "sync" "github.com/fosrl/gerbil/internal/observability" ) @@ -24,6 +26,7 @@ type OTelConfig = observability.OTelConfig var ( backend observability.Backend + initMu sync.Mutex // Interface and peer metrics wgInterfaceUp observability.Int64Gauge @@ -55,7 +58,6 @@ var ( udpPacketSizeBytes observability.Histogram holePunchEventsTotal observability.Counter proxyMappingActive observability.UpDownCounter - sessionActive observability.UpDownCounter sessionRebuiltTotal observability.Counter commPatternActive observability.UpDownCounter proxyCleanupRemovedTotal observability.Counter @@ -107,6 +109,13 @@ func DefaultConfig() Config { // Initialize sets up the metrics system using the selected backend. // It returns the /metrics HTTP handler (non-nil only for Prometheus backend). 
func Initialize(cfg Config) (http.Handler, error) { + initMu.Lock() + defer initMu.Unlock() + + if backend != nil { + return backend.HTTPHandler(), nil + } + b, err := observability.New(cfg) if err != nil { return nil, err @@ -114,6 +123,7 @@ func Initialize(cfg Config) (http.Handler, error) { backend = b if err := createInstruments(); err != nil { + backend = nil return nil, err } @@ -122,8 +132,13 @@ func Initialize(cfg Config) (http.Handler, error) { // Shutdown gracefully shuts down the metrics backend. func Shutdown(ctx context.Context) error { - if backend != nil { - return backend.Shutdown(ctx) + initMu.Lock() + b := backend + backend = nil + initMu.Unlock() + + if b != nil { + return b.Shutdown(ctx) } return nil } @@ -135,129 +150,346 @@ func createInstruments() error { b := backend - wgInterfaceUp = b.NewInt64Gauge("gerbil_wg_interface_up", + newCounter := func(name, desc string, labelNames ...string) (observability.Counter, error) { + c, err := b.NewCounter(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create counter %q: %w", name, err) + } + return c, nil + } + + newUpDownCounter := func(name, desc string, labelNames ...string) (observability.UpDownCounter, error) { + c, err := b.NewUpDownCounter(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create updown counter %q: %w", name, err) + } + return c, nil + } + + newInt64Gauge := func(name, desc string, labelNames ...string) (observability.Int64Gauge, error) { + g, err := b.NewInt64Gauge(name, desc, labelNames...) + if err != nil { + return nil, fmt.Errorf("create int64 gauge %q: %w", name, err) + } + return g, nil + } + + newFloat64Gauge := func(name, desc string, labelNames ...string) (observability.Float64Gauge, error) { + g, err := b.NewFloat64Gauge(name, desc, labelNames...) 
+ if err != nil { + return nil, fmt.Errorf("create float64 gauge %q: %w", name, err) + } + return g, nil + } + + newHistogram := func(name, desc string, buckets []float64, labelNames ...string) (observability.Histogram, error) { + h, err := b.NewHistogram(name, desc, buckets, labelNames...) + if err != nil { + return nil, fmt.Errorf("create histogram %q: %w", name, err) + } + return h, nil + } + + var err error + + wgInterfaceUp, err = newInt64Gauge("gerbil_wg_interface_up", "Operational state of a WireGuard interface (1=up, 0=down)", "ifname", "instance") - wgPeersTotal = b.NewUpDownCounter("gerbil_wg_peers_total", + if err != nil { + return err + } + wgPeersTotal, err = newUpDownCounter("gerbil_wg_peers_total", "Total number of configured peers per interface", "ifname") - wgPeerConnected = b.NewInt64Gauge("gerbil_wg_peer_connected", + if err != nil { + return err + } + wgPeerConnected, err = newInt64Gauge("gerbil_wg_peer_connected", "Whether a specific peer is connected (1=connected, 0=disconnected)", "ifname", "peer") - allowedIPsCount = b.NewUpDownCounter("gerbil_allowed_ips_count", + if err != nil { + return err + } + allowedIPsCount, err = newUpDownCounter("gerbil_allowed_ips_count", "Number of allowed IPs configured per peer", "ifname", "peer") - keyRotationTotal = b.NewCounter("gerbil_key_rotation_total", + if err != nil { + return err + } + keyRotationTotal, err = newCounter("gerbil_key_rotation_total", "Key rotation events", "ifname", "reason") - wgHandshakesTotal = b.NewCounter("gerbil_wg_handshakes_total", + if err != nil { + return err + } + wgHandshakesTotal, err = newCounter("gerbil_wg_handshakes_total", "Count of handshake attempts with their result status", "ifname", "peer", "result") - wgHandshakeLatency = b.NewHistogram("gerbil_wg_handshake_latency_seconds", + if err != nil { + return err + } + wgHandshakeLatency, err = newHistogram("gerbil_wg_handshake_latency_seconds", "Distribution of handshake latencies in seconds", durationBuckets, "ifname", 
"peer") - wgPeerRTT = b.NewHistogram("gerbil_wg_peer_rtt_seconds", + if err != nil { + return err + } + wgPeerRTT, err = newHistogram("gerbil_wg_peer_rtt_seconds", "Observed round-trip time to a peer in seconds", durationBuckets, "ifname", "peer") - wgBytesReceived = b.NewCounter("gerbil_wg_bytes_received_total", + if err != nil { + return err + } + wgBytesReceived, err = newCounter("gerbil_wg_bytes_received_total", "Number of bytes received from a peer", "ifname", "peer") - wgBytesTransmitted = b.NewCounter("gerbil_wg_bytes_transmitted_total", + if err != nil { + return err + } + wgBytesTransmitted, err = newCounter("gerbil_wg_bytes_transmitted_total", "Number of bytes transmitted to a peer", "ifname", "peer") - netlinkEventsTotal = b.NewCounter("gerbil_netlink_events_total", + if err != nil { + return err + } + netlinkEventsTotal, err = newCounter("gerbil_netlink_events_total", "Number of netlink events processed", "event_type") - netlinkErrorsTotal = b.NewCounter("gerbil_netlink_errors_total", + if err != nil { + return err + } + netlinkErrorsTotal, err = newCounter("gerbil_netlink_errors_total", "Count of netlink or kernel errors", "component", "error_type") - syncDuration = b.NewHistogram("gerbil_sync_duration_seconds", + if err != nil { + return err + } + syncDuration, err = newHistogram("gerbil_sync_duration_seconds", "Duration of reconciliation/sync loops in seconds", durationBuckets, "component") - workqueueDepth = b.NewUpDownCounter("gerbil_workqueue_depth", + if err != nil { + return err + } + workqueueDepth, err = newUpDownCounter("gerbil_workqueue_depth", "Current length of internal work queues", "queue") - kernelModuleLoads = b.NewCounter("gerbil_kernel_module_loads_total", + if err != nil { + return err + } + kernelModuleLoads, err = newCounter("gerbil_kernel_module_loads_total", "Count of kernel module load attempts", "result") - firewallRulesApplied = b.NewCounter("gerbil_firewall_rules_applied_total", + if err != nil { + return err + } + 
firewallRulesApplied, err = newCounter("gerbil_firewall_rules_applied_total", "IPTables/NFT rules applied", "result", "chain") - activeSessions = b.NewUpDownCounter("gerbil_active_sessions", + if err != nil { + return err + } + activeSessions, err = newUpDownCounter("gerbil_active_sessions", "Number of active UDP relay sessions", "ifname") - activeProxyConnections = b.NewUpDownCounter("gerbil_active_proxy_connections", + if err != nil { + return err + } + activeProxyConnections, err = newUpDownCounter("gerbil_active_proxy_connections", "Active SNI proxy connections") - proxyRouteLookups = b.NewCounter("gerbil_proxy_route_lookups_total", + if err != nil { + return err + } + proxyRouteLookups, err = newCounter("gerbil_proxy_route_lookups_total", "Number of route lookups", "result") - proxyTLSHandshake = b.NewHistogram("gerbil_proxy_tls_handshake_seconds", + if err != nil { + return err + } + proxyTLSHandshake, err = newHistogram("gerbil_proxy_tls_handshake_seconds", "TLS handshake duration for SNI proxy in seconds", durationBuckets) - proxyBytesTransmitted = b.NewCounter("gerbil_proxy_bytes_transmitted_total", + if err != nil { + return err + } + proxyBytesTransmitted, err = newCounter("gerbil_proxy_bytes_transmitted_total", "Bytes sent/received by the SNI proxy", "direction") - configReloadsTotal = b.NewCounter("gerbil_config_reloads_total", + if err != nil { + return err + } + configReloadsTotal, err = newCounter("gerbil_config_reloads_total", "Number of configuration reloads", "result") - restartTotal = b.NewCounter("gerbil_restart_total", + if err != nil { + return err + } + restartTotal, err = newCounter("gerbil_restart_total", "Process restart count") - authFailuresTotal = b.NewCounter("gerbil_auth_failures_total", + if err != nil { + return err + } + authFailuresTotal, err = newCounter("gerbil_auth_failures_total", "Count of authentication or peer validation failures", "peer", "reason") - aclDeniedTotal = b.NewCounter("gerbil_acl_denied_total", + if err != nil 
{ + return err + } + aclDeniedTotal, err = newCounter("gerbil_acl_denied_total", "Access control denied events", "ifname", "peer", "policy") - certificateExpiryDays = b.NewFloat64Gauge("gerbil_certificate_expiry_days", + if err != nil { + return err + } + certificateExpiryDays, err = newFloat64Gauge("gerbil_certificate_expiry_days", "Days until certificate expiry", "cert_name", "ifname") - udpPacketsTotal = b.NewCounter("gerbil_udp_packets_total", + if err != nil { + return err + } + udpPacketsTotal, err = newCounter("gerbil_udp_packets_total", "Count of UDP packets processed by relay workers", "ifname", "type", "direction") - udpPacketSizeBytes = b.NewHistogram("gerbil_udp_packet_size_bytes", + if err != nil { + return err + } + udpPacketSizeBytes, err = newHistogram("gerbil_udp_packet_size_bytes", "Size distribution of packets forwarded through relay", sizeBuckets, "ifname", "type") - holePunchEventsTotal = b.NewCounter("gerbil_hole_punch_events_total", + if err != nil { + return err + } + holePunchEventsTotal, err = newCounter("gerbil_hole_punch_events_total", "Count of hole punch messages processed", "ifname", "result") - proxyMappingActive = b.NewUpDownCounter("gerbil_proxy_mapping_active", + if err != nil { + return err + } + proxyMappingActive, err = newUpDownCounter("gerbil_proxy_mapping_active", "Number of active proxy mappings", "ifname") - sessionActive = b.NewUpDownCounter("gerbil_session_active", - "Number of active WireGuard sessions", "ifname") - sessionRebuiltTotal = b.NewCounter("gerbil_session_rebuilt_total", + if err != nil { + return err + } + sessionRebuiltTotal, err = newCounter("gerbil_session_rebuilt_total", "Count of sessions rebuilt from communication patterns", "ifname") - commPatternActive = b.NewUpDownCounter("gerbil_comm_pattern_active", + if err != nil { + return err + } + commPatternActive, err = newUpDownCounter("gerbil_comm_pattern_active", "Number of active communication patterns", "ifname") - proxyCleanupRemovedTotal = 
b.NewCounter("gerbil_proxy_cleanup_removed_total", + if err != nil { + return err + } + proxyCleanupRemovedTotal, err = newCounter("gerbil_proxy_cleanup_removed_total", "Count of items removed during cleanup routines", "ifname", "component") - proxyConnectionErrorsTotal = b.NewCounter("gerbil_proxy_connection_errors_total", + if err != nil { + return err + } + proxyConnectionErrorsTotal, err = newCounter("gerbil_proxy_connection_errors_total", "Count of connection errors in proxy operations", "ifname", "error_type") - proxyInitialMappingsTotal = b.NewInt64Gauge("gerbil_proxy_initial_mappings", + if err != nil { + return err + } + proxyInitialMappingsTotal, err = newInt64Gauge("gerbil_proxy_initial_mappings", "Number of initial proxy mappings loaded", "ifname") - proxyMappingUpdatesTotal = b.NewCounter("gerbil_proxy_mapping_updates_total", + if err != nil { + return err + } + proxyMappingUpdatesTotal, err = newCounter("gerbil_proxy_mapping_updates_total", "Count of proxy mapping updates", "ifname") - proxyIdleCleanupDuration = b.NewHistogram("gerbil_proxy_idle_cleanup_duration_seconds", + if err != nil { + return err + } + proxyIdleCleanupDuration, err = newHistogram("gerbil_proxy_idle_cleanup_duration_seconds", "Duration of cleanup cycles", durationBuckets, "ifname", "component") - sniConnectionsTotal = b.NewCounter("gerbil_sni_connections_total", + if err != nil { + return err + } + sniConnectionsTotal, err = newCounter("gerbil_sni_connections_total", "Count of connections processed by SNI proxy", "result") - sniConnectionDuration = b.NewHistogram("gerbil_sni_connection_duration_seconds", + if err != nil { + return err + } + sniConnectionDuration, err = newHistogram("gerbil_sni_connection_duration_seconds", "Lifetime distribution of proxied TLS connections", sniDurationBuckets) - sniActiveConnections = b.NewUpDownCounter("gerbil_sni_active_connections", + if err != nil { + return err + } + sniActiveConnections, err = 
newUpDownCounter("gerbil_sni_active_connections", "Number of active SNI tunnels") - sniRouteCacheHitsTotal = b.NewCounter("gerbil_sni_route_cache_hits_total", + if err != nil { + return err + } + sniRouteCacheHitsTotal, err = newCounter("gerbil_sni_route_cache_hits_total", "Count of route cache hits and misses", "result") - sniRouteAPIRequestsTotal = b.NewCounter("gerbil_sni_route_api_requests_total", + if err != nil { + return err + } + sniRouteAPIRequestsTotal, err = newCounter("gerbil_sni_route_api_requests_total", "Count of route API requests", "result") - sniRouteAPILatency = b.NewHistogram("gerbil_sni_route_api_latency_seconds", + if err != nil { + return err + } + sniRouteAPILatency, err = newHistogram("gerbil_sni_route_api_latency_seconds", "Distribution of route API call latencies", durationBuckets) - sniLocalOverrideTotal = b.NewCounter("gerbil_sni_local_override_total", + if err != nil { + return err + } + sniLocalOverrideTotal, err = newCounter("gerbil_sni_local_override_total", "Count of routes using local overrides", "hit") - sniTrustedProxyEventsTotal = b.NewCounter("gerbil_sni_trusted_proxy_events_total", + if err != nil { + return err + } + sniTrustedProxyEventsTotal, err = newCounter("gerbil_sni_trusted_proxy_events_total", "Count of PROXY protocol events", "event") - sniProxyProtocolParseErrorsTotal = b.NewCounter("gerbil_sni_proxy_protocol_parse_errors_total", + if err != nil { + return err + } + sniProxyProtocolParseErrorsTotal, err = newCounter("gerbil_sni_proxy_protocol_parse_errors_total", "Count of PROXY protocol parse failures") - sniDataBytesTotal = b.NewCounter("gerbil_sni_data_bytes_total", + if err != nil { + return err + } + sniDataBytesTotal, err = newCounter("gerbil_sni_data_bytes_total", "Count of bytes proxied through SNI tunnels", "direction") - sniTunnelTerminationsTotal = b.NewCounter("gerbil_sni_tunnel_terminations_total", + if err != nil { + return err + } + sniTunnelTerminationsTotal, err = 
newCounter("gerbil_sni_tunnel_terminations_total", "Count of tunnel terminations by reason", "reason") - httpRequestsTotal = b.NewCounter("gerbil_http_requests_total", + if err != nil { + return err + } + httpRequestsTotal, err = newCounter("gerbil_http_requests_total", "Count of HTTP requests to management API", "endpoint", "method", "status_code") - httpRequestDuration = b.NewHistogram("gerbil_http_request_duration_seconds", + if err != nil { + return err + } + httpRequestDuration, err = newHistogram("gerbil_http_request_duration_seconds", "Distribution of HTTP request handling time", durationBuckets, "endpoint", "method") - peerOperationsTotal = b.NewCounter("gerbil_peer_operations_total", + if err != nil { + return err + } + peerOperationsTotal, err = newCounter("gerbil_peer_operations_total", "Count of peer lifecycle operations", "operation", "result") - proxyMappingUpdateRequestsTotal = b.NewCounter("gerbil_proxy_mapping_update_requests_total", + if err != nil { + return err + } + proxyMappingUpdateRequestsTotal, err = newCounter("gerbil_proxy_mapping_update_requests_total", "Count of proxy mapping update API calls", "result") - destinationsUpdateRequestsTotal = b.NewCounter("gerbil_destinations_update_requests_total", + if err != nil { + return err + } + destinationsUpdateRequestsTotal, err = newCounter("gerbil_destinations_update_requests_total", "Count of destinations update API calls", "result") - remoteConfigFetchesTotal = b.NewCounter("gerbil_remote_config_fetches_total", + if err != nil { + return err + } + remoteConfigFetchesTotal, err = newCounter("gerbil_remote_config_fetches_total", "Count of remote configuration fetch attempts", "result") - bandwidthReportsTotal = b.NewCounter("gerbil_bandwidth_reports_total", + if err != nil { + return err + } + bandwidthReportsTotal, err = newCounter("gerbil_bandwidth_reports_total", "Count of bandwidth report transmissions", "result") - peerBandwidthBytesTotal = b.NewCounter("gerbil_peer_bandwidth_bytes_total", 
+ if err != nil { + return err + } + peerBandwidthBytesTotal, err = newCounter("gerbil_peer_bandwidth_bytes_total", "Bytes per peer tracked by bandwidth calculation", "peer", "direction") - memorySpikeTotal = b.NewCounter("gerbil_memory_spike_total", + if err != nil { + return err + } + memorySpikeTotal, err = newCounter("gerbil_memory_spike_total", "Count of memory spikes detected", "severity") - heapProfilesWrittenTotal = b.NewCounter("gerbil_heap_profiles_written_total", + if err != nil { + return err + } + heapProfilesWrittenTotal, err = newCounter("gerbil_heap_profiles_written_total", "Count of heap profile files generated") + if err != nil { + return err + } return nil } func RecordInterfaceUp(ifname, instance string, up bool) { + if wgInterfaceUp == nil { + return + } value := int64(0) if up { value = 1 @@ -266,10 +498,16 @@ func RecordInterfaceUp(ifname, instance string, up bool) { } func RecordPeersTotal(ifname string, delta int64) { + if wgPeersTotal == nil { + return + } wgPeersTotal.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordPeerConnected(ifname, peer string, connected bool) { + if wgPeerConnected == nil { + return + } value := int64(0) if connected { value = 1 @@ -278,229 +516,393 @@ func RecordPeerConnected(ifname, peer string, connected bool) { } func RecordHandshake(ifname, peer, result string) { + if wgHandshakesTotal == nil { + return + } wgHandshakesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "result": result}) } func RecordHandshakeLatency(ifname, peer string, seconds float64) { + if wgHandshakeLatency == nil { + return + } wgHandshakeLatency.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordPeerRTT(ifname, peer string, seconds float64) { + if wgPeerRTT == nil { + return + } wgPeerRTT.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "peer": peer}) } func 
RecordBytesReceived(ifname, peer string, bytes int64) { + if wgBytesReceived == nil { + return + } wgBytesReceived.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordBytesTransmitted(ifname, peer string, bytes int64) { + if wgBytesTransmitted == nil { + return + } wgBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordAllowedIPsCount(ifname, peer string, delta int64) { + if allowedIPsCount == nil { + return + } allowedIPsCount.Add(context.Background(), delta, observability.Labels{"ifname": ifname, "peer": peer}) } func RecordKeyRotation(ifname, reason string) { + if keyRotationTotal == nil { + return + } keyRotationTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "reason": reason}) } func RecordNetlinkEvent(eventType string) { + if netlinkEventsTotal == nil { + return + } netlinkEventsTotal.Add(context.Background(), 1, observability.Labels{"event_type": eventType}) } func RecordNetlinkError(component, errorType string) { + if netlinkErrorsTotal == nil { + return + } netlinkErrorsTotal.Add(context.Background(), 1, observability.Labels{"component": component, "error_type": errorType}) } func RecordSyncDuration(component string, seconds float64) { + if syncDuration == nil { + return + } syncDuration.Record(context.Background(), seconds, observability.Labels{"component": component}) } func RecordWorkqueueDepth(queue string, delta int64) { + if workqueueDepth == nil { + return + } workqueueDepth.Add(context.Background(), delta, observability.Labels{"queue": queue}) } func RecordKernelModuleLoad(result string) { + if kernelModuleLoads == nil { + return + } kernelModuleLoads.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordFirewallRuleApplied(result, chain string) { + if firewallRulesApplied == nil { + return + } firewallRulesApplied.Add(context.Background(), 1, observability.Labels{"result": result, 
"chain": chain}) } func RecordActiveSession(ifname string, delta int64) { + if activeSessions == nil { + return + } activeSessions.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } -func RecordActiveProxyConnection(hostname string, delta int64) { - _ = hostname +func RecordActiveProxyConnection(delta int64) { + if activeProxyConnections == nil { + return + } activeProxyConnections.Add(context.Background(), delta, nil) } -func RecordProxyRouteLookup(result, hostname string) { - _ = hostname +func RecordProxyRouteLookup(result string) { + if proxyRouteLookups == nil { + return + } proxyRouteLookups.Add(context.Background(), 1, observability.Labels{"result": result}) } -func RecordProxyTLSHandshake(hostname string, seconds float64) { - _ = hostname +func RecordProxyTLSHandshake(seconds float64) { + if proxyTLSHandshake == nil { + return + } proxyTLSHandshake.Record(context.Background(), seconds, nil) } -func RecordProxyBytesTransmitted(hostname, direction string, bytes int64) { - _ = hostname +func RecordProxyBytesTransmitted(direction string, bytes int64) { + if proxyBytesTransmitted == nil { + return + } proxyBytesTransmitted.Add(context.Background(), bytes, observability.Labels{"direction": direction}) } func RecordConfigReload(result string) { + if configReloadsTotal == nil { + return + } configReloadsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordRestart() { + if restartTotal == nil { + return + } restartTotal.Add(context.Background(), 1, nil) } func RecordAuthFailure(peer, reason string) { + if authFailuresTotal == nil { + return + } authFailuresTotal.Add(context.Background(), 1, observability.Labels{"peer": peer, "reason": reason}) } func RecordACLDenied(ifname, peer, policy string) { + if aclDeniedTotal == nil { + return + } aclDeniedTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "peer": peer, "policy": policy}) } func RecordCertificateExpiry(certName, ifname string, 
days float64) { + if certificateExpiryDays == nil { + return + } certificateExpiryDays.Record(context.Background(), days, observability.Labels{"cert_name": certName, "ifname": ifname}) } func RecordUDPPacket(ifname, packetType, direction string) { + if udpPacketsTotal == nil { + return + } udpPacketsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "type": packetType, "direction": direction}) } func RecordUDPPacketSize(ifname, packetType string, bytes float64) { + if udpPacketSizeBytes == nil { + return + } udpPacketSizeBytes.Record(context.Background(), bytes, observability.Labels{"ifname": ifname, "type": packetType}) } func RecordHolePunchEvent(ifname, result string) { + if holePunchEventsTotal == nil { + return + } holePunchEventsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "result": result}) } func RecordProxyMapping(ifname string, delta int64) { + if proxyMappingActive == nil { + return + } proxyMappingActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordSession(ifname string, delta int64) { - sessionActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) + if activeSessions == nil { + return + } + activeSessions.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordSessionRebuilt(ifname string) { + if sessionRebuiltTotal == nil { + return + } sessionRebuiltTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) } func RecordCommPattern(ifname string, delta int64) { + if commPatternActive == nil { + return + } commPatternActive.Add(context.Background(), delta, observability.Labels{"ifname": ifname}) } func RecordProxyCleanupRemoved(ifname, component string, count int64) { + if proxyCleanupRemovedTotal == nil { + return + } proxyCleanupRemovedTotal.Add(context.Background(), count, observability.Labels{"ifname": ifname, "component": component}) } func RecordProxyConnectionError(ifname, errorType 
string) { + if proxyConnectionErrorsTotal == nil { + return + } proxyConnectionErrorsTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname, "error_type": errorType}) } func RecordProxyInitialMappings(ifname string, count int64) { + if proxyInitialMappingsTotal == nil { + return + } proxyInitialMappingsTotal.Record(context.Background(), count, observability.Labels{"ifname": ifname}) } func RecordProxyMappingUpdate(ifname string) { + if proxyMappingUpdatesTotal == nil { + return + } proxyMappingUpdatesTotal.Add(context.Background(), 1, observability.Labels{"ifname": ifname}) } func RecordProxyIdleCleanupDuration(ifname, component string, seconds float64) { + if proxyIdleCleanupDuration == nil { + return + } proxyIdleCleanupDuration.Record(context.Background(), seconds, observability.Labels{"ifname": ifname, "component": component}) } func RecordSNIConnection(result string) { + if sniConnectionsTotal == nil { + return + } sniConnectionsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIConnectionDuration(seconds float64) { + if sniConnectionDuration == nil { + return + } sniConnectionDuration.Record(context.Background(), seconds, nil) } func RecordSNIActiveConnection(delta int64) { + if sniActiveConnections == nil { + return + } sniActiveConnections.Add(context.Background(), delta, nil) } func RecordSNIRouteCacheHit(result string) { + if sniRouteCacheHitsTotal == nil { + return + } sniRouteCacheHitsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIRouteAPIRequest(result string) { + if sniRouteAPIRequestsTotal == nil { + return + } sniRouteAPIRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordSNIRouteAPILatency(seconds float64) { + if sniRouteAPILatency == nil { + return + } sniRouteAPILatency.Record(context.Background(), seconds, nil) } func RecordSNILocalOverride(hit string) { + if sniLocalOverrideTotal == nil { + return + } 
sniLocalOverrideTotal.Add(context.Background(), 1, observability.Labels{"hit": hit}) } func RecordSNITrustedProxyEvent(event string) { + if sniTrustedProxyEventsTotal == nil { + return + } sniTrustedProxyEventsTotal.Add(context.Background(), 1, observability.Labels{"event": event}) } func RecordSNIProxyProtocolParseError() { + if sniProxyProtocolParseErrorsTotal == nil { + return + } sniProxyProtocolParseErrorsTotal.Add(context.Background(), 1, nil) } func RecordSNIDataBytes(direction string, bytes int64) { + if sniDataBytesTotal == nil { + return + } sniDataBytesTotal.Add(context.Background(), bytes, observability.Labels{"direction": direction}) } func RecordSNITunnelTermination(reason string) { + if sniTunnelTerminationsTotal == nil { + return + } sniTunnelTerminationsTotal.Add(context.Background(), 1, observability.Labels{"reason": reason}) } func RecordHTTPRequest(endpoint, method, statusCode string) { + if httpRequestsTotal == nil { + return + } httpRequestsTotal.Add(context.Background(), 1, observability.Labels{"endpoint": endpoint, "method": method, "status_code": statusCode}) } func RecordHTTPRequestDuration(endpoint, method string, seconds float64) { + if httpRequestDuration == nil { + return + } httpRequestDuration.Record(context.Background(), seconds, observability.Labels{"endpoint": endpoint, "method": method}) } func RecordPeerOperation(operation, result string) { + if peerOperationsTotal == nil { + return + } peerOperationsTotal.Add(context.Background(), 1, observability.Labels{"operation": operation, "result": result}) } func RecordProxyMappingUpdateRequest(result string) { + if proxyMappingUpdateRequestsTotal == nil { + return + } proxyMappingUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordDestinationsUpdateRequest(result string) { + if destinationsUpdateRequestsTotal == nil { + return + } destinationsUpdateRequestsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func 
RecordRemoteConfigFetch(result string) { + if remoteConfigFetchesTotal == nil { + return + } remoteConfigFetchesTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordBandwidthReport(result string) { + if bandwidthReportsTotal == nil { + return + } bandwidthReportsTotal.Add(context.Background(), 1, observability.Labels{"result": result}) } func RecordPeerBandwidthBytes(peer, direction string, bytes int64) { + if peerBandwidthBytesTotal == nil { + return + } peerBandwidthBytesTotal.Add(context.Background(), bytes, observability.Labels{"peer": peer, "direction": direction}) } func RecordMemorySpike(severity string) { + if memorySpikeTotal == nil { + return + } memorySpikeTotal.Add(context.Background(), 1, observability.Labels{"severity": severity}) } func RecordHeapProfileWritten() { + if heapProfilesWrittenTotal == nil { + return + } heapProfilesWrittenTotal.Add(context.Background(), 1, nil) } diff --git a/internal/metrics/metrics_test.go b/internal/metrics/metrics_test.go index 132c3fe..8c01c68 100644 --- a/internal/metrics/metrics_test.go +++ b/internal/metrics/metrics_test.go @@ -89,6 +89,9 @@ func TestDefaultConfig(t *testing.T) { } func TestShutdownNoInit(t *testing.T) { + // Ensure a known clean global state before testing no-init shutdown behavior. + _ = metrics.Shutdown(context.Background()) + // Shutdown without Initialize should not panic or error. 
if err := metrics.Shutdown(context.Background()); err != nil { t.Errorf("unexpected error: %v", err) @@ -168,6 +171,7 @@ func TestRecordRelay(t *testing.T) { body := scrape(t, h) assertContains(t, body, "gerbil_udp_packets_total") assertContains(t, body, "gerbil_proxy_mapping_active") + assertContains(t, body, "gerbil_active_sessions") } func TestRecordWireGuard(t *testing.T) { @@ -216,10 +220,10 @@ func TestRecordNetlink(t *testing.T) { metrics.RecordKernelModuleLoad("success") metrics.RecordFirewallRuleApplied("success", "INPUT") metrics.RecordActiveSession("wg0", 1) - metrics.RecordActiveProxyConnection(exampleHostname, 1) - metrics.RecordProxyRouteLookup("hit", exampleHostname) - metrics.RecordProxyTLSHandshake(exampleHostname, 0.05) - metrics.RecordProxyBytesTransmitted(exampleHostname, "tx", 1024) + metrics.RecordActiveProxyConnection(1) + metrics.RecordProxyRouteLookup("hit") + metrics.RecordProxyTLSHandshake(0.05) + metrics.RecordProxyBytesTransmitted("tx", 1024) body := scrape(t, h) assertContains(t, body, "gerbil_netlink_events_total") assertContains(t, body, "gerbil_active_sessions") From 73d4d4d37cbc8a108dc66a10cc99f158d1b33ca5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:23 +0200 Subject: [PATCH 21/25] feat(observability): unify backend APIs and harden OTel handling --- internal/observability/config.go | 12 +- internal/observability/metrics.go | 31 ++-- internal/observability/metrics_test.go | 99 ++++++++++-- internal/observability/noop.go | 27 ++-- internal/observability/noop_test.go | 67 ++++++-- internal/observability/otel/backend.go | 163 ++++++++++++++++---- internal/observability/otel/backend_test.go | 44 +++++- internal/observability/otel/exporter.go | 22 ++- 8 files changed, 360 insertions(+), 105 deletions(-) diff --git a/internal/observability/config.go b/internal/observability/config.go index 9643727..7ac9ba2 100644 --- a/internal/observability/config.go +++ b/internal/observability/config.go @@ 
-60,6 +60,10 @@ type OTelConfig struct { // ExportInterval is how often metrics are pushed to the collector. // Defaults to 60 s. ExportInterval time.Duration + + // Timeout bounds OTLP exporter construction calls. + // Defaults to 10 s. + Timeout time.Duration } // DefaultMetricsConfig returns a MetricsConfig with sensible defaults. @@ -75,6 +79,7 @@ func DefaultMetricsConfig() MetricsConfig { Endpoint: "localhost:4317", Insecure: true, ExportInterval: 60 * time.Second, + Timeout: 10 * time.Second, }, ServiceName: "gerbil", ServiceVersion: "1.0.0", @@ -88,8 +93,10 @@ func (c *MetricsConfig) Validate() error { } switch c.Backend { - case "prometheus", "none", "": + case "prometheus", "none": // valid + case "": + return fmt.Errorf("metrics: enabled requires a non-empty backend") case "otel": if c.OTel.Endpoint == "" { return fmt.Errorf("metrics: backend=otel requires a non-empty OTel endpoint") @@ -100,6 +107,9 @@ func (c *MetricsConfig) Validate() error { if c.OTel.ExportInterval <= 0 { return fmt.Errorf("metrics: otel export interval must be positive") } + if c.OTel.Timeout <= 0 { + return fmt.Errorf("metrics: otel timeout must be positive") + } default: return fmt.Errorf("metrics: unknown backend %q (must be \"prometheus\", \"otel\", or \"none\")", c.Backend) } diff --git a/internal/observability/metrics.go b/internal/observability/metrics.go index ff2e2ae..9ea0013 100644 --- a/internal/observability/metrics.go +++ b/internal/observability/metrics.go @@ -43,20 +43,20 @@ type Histogram interface { type Backend interface { // NewCounter creates a counter metric. // labelNames declares the set of label keys that will be passed at observation time. - NewCounter(name, desc string, labelNames ...string) Counter + NewCounter(name, desc string, labelNames ...string) (Counter, error) // NewUpDownCounter creates an up-down counter metric. 
- NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter + NewUpDownCounter(name, desc string, labelNames ...string) (UpDownCounter, error) // NewInt64Gauge creates an integer gauge metric. - NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge + NewInt64Gauge(name, desc string, labelNames ...string) (Int64Gauge, error) // NewFloat64Gauge creates a float gauge metric. - NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge + NewFloat64Gauge(name, desc string, labelNames ...string) (Float64Gauge, error) // NewHistogram creates a histogram metric. // buckets are the explicit upper-bound bucket boundaries. - NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram + NewHistogram(name, desc string, buckets []float64, labelNames ...string) (Histogram, error) // HTTPHandler returns the /metrics HTTP handler. // Implementations that do not expose an HTTP endpoint return nil. @@ -88,6 +88,7 @@ func New(cfg MetricsConfig) (Backend, error) { Endpoint: cfg.OTel.Endpoint, Insecure: cfg.OTel.Insecure, ExportInterval: cfg.OTel.ExportInterval, + Timeout: cfg.OTel.Timeout, ServiceName: cfg.ServiceName, ServiceVersion: cfg.ServiceVersion, DeploymentEnvironment: cfg.DeploymentEnvironment, @@ -110,19 +111,19 @@ type promAdapter struct { b *obsprom.Backend } -func (a *promAdapter) NewCounter(name, desc string, labelNames ...string) Counter { +func (a *promAdapter) NewCounter(name, desc string, labelNames ...string) (Counter, error) { return a.b.NewCounter(name, desc, labelNames...) } -func (a *promAdapter) NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter { +func (a *promAdapter) NewUpDownCounter(name, desc string, labelNames ...string) (UpDownCounter, error) { return a.b.NewUpDownCounter(name, desc, labelNames...) 
} -func (a *promAdapter) NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge { +func (a *promAdapter) NewInt64Gauge(name, desc string, labelNames ...string) (Int64Gauge, error) { return a.b.NewInt64Gauge(name, desc, labelNames...) } -func (a *promAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge { +func (a *promAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) (Float64Gauge, error) { return a.b.NewFloat64Gauge(name, desc, labelNames...) } -func (a *promAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram { +func (a *promAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) (Histogram, error) { return a.b.NewHistogram(name, desc, buckets, labelNames...) } func (a *promAdapter) HTTPHandler() http.Handler { return a.b.HTTPHandler() } @@ -133,19 +134,19 @@ type otelAdapter struct { b *obsotel.Backend } -func (a *otelAdapter) NewCounter(name, desc string, labelNames ...string) Counter { +func (a *otelAdapter) NewCounter(name, desc string, labelNames ...string) (Counter, error) { return a.b.NewCounter(name, desc, labelNames...) } -func (a *otelAdapter) NewUpDownCounter(name, desc string, labelNames ...string) UpDownCounter { +func (a *otelAdapter) NewUpDownCounter(name, desc string, labelNames ...string) (UpDownCounter, error) { return a.b.NewUpDownCounter(name, desc, labelNames...) } -func (a *otelAdapter) NewInt64Gauge(name, desc string, labelNames ...string) Int64Gauge { +func (a *otelAdapter) NewInt64Gauge(name, desc string, labelNames ...string) (Int64Gauge, error) { return a.b.NewInt64Gauge(name, desc, labelNames...) } -func (a *otelAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) Float64Gauge { +func (a *otelAdapter) NewFloat64Gauge(name, desc string, labelNames ...string) (Float64Gauge, error) { return a.b.NewFloat64Gauge(name, desc, labelNames...) 
} -func (a *otelAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) Histogram { +func (a *otelAdapter) NewHistogram(name, desc string, buckets []float64, labelNames ...string) (Histogram, error) { return a.b.NewHistogram(name, desc, buckets, labelNames...) } func (a *otelAdapter) HTTPHandler() http.Handler { return a.b.HTTPHandler() } diff --git a/internal/observability/metrics_test.go b/internal/observability/metrics_test.go index 91a048c..cf3eccb 100644 --- a/internal/observability/metrics_test.go +++ b/internal/observability/metrics_test.go @@ -2,6 +2,8 @@ package observability_test import ( "context" + "net" + "os" "testing" "time" @@ -39,20 +41,19 @@ func TestValidateValidConfigs(t *testing.T) { }{ {name: "disabled", cfg: observability.MetricsConfig{Enabled: false}}, {name: "backend none", cfg: observability.MetricsConfig{Enabled: true, Backend: "none"}}, - {name: "backend empty", cfg: observability.MetricsConfig{Enabled: true, Backend: ""}}, {name: "prometheus", cfg: observability.MetricsConfig{Enabled: true, Backend: "prometheus"}}, { name: "otel grpc", cfg: observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second}, + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second, Timeout: 2 * time.Second}, }, }, { name: "otel http", cfg: observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "http", Endpoint: "localhost:4318", ExportInterval: 30 * time.Second}, + OTel: observability.OTelConfig{Protocol: "http", Endpoint: "localhost:4318", ExportInterval: 30 * time.Second, Timeout: 2 * time.Second}, }, }, } @@ -71,25 +72,36 @@ func TestValidateInvalidConfigs(t *testing.T) { cfg observability.MetricsConfig }{ {name: "unknown backend", cfg: observability.MetricsConfig{Enabled: true, Backend: "datadog"}}, + { + name: 
"backend empty while enabled", + cfg: observability.MetricsConfig{Enabled: true, Backend: ""}, + }, { name: "otel missing endpoint", cfg: observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: "", ExportInterval: 10 * time.Second}, + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: "", ExportInterval: 10 * time.Second, Timeout: 2 * time.Second}, }, }, { name: "otel invalid protocol", cfg: observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "tcp", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second}, + OTel: observability.OTelConfig{Protocol: "tcp", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second, Timeout: 2 * time.Second}, }, }, { name: "otel zero interval", cfg: observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 0}, + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 0, Timeout: 2 * time.Second}, + }, + }, + { + name: "otel zero timeout", + cfg: observability.MetricsConfig{ + Enabled: true, Backend: "otel", + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, ExportInterval: 10 * time.Second, Timeout: 0}, }, }, } @@ -157,11 +169,32 @@ func TestPrometheusAdapterAllInstruments(t *testing.T) { ctx := context.Background() labels := observability.Labels{"k": "v"} - b.NewCounter("prom_adapter_counter_total", "desc", "k").Add(ctx, 1, labels) - b.NewUpDownCounter("prom_adapter_updown", "desc", "k").Add(ctx, 2, labels) - b.NewInt64Gauge("prom_adapter_int_gauge", "desc", "k").Record(ctx, 99, labels) - b.NewFloat64Gauge("prom_adapter_float_gauge", "desc", "k").Record(ctx, 1.23, labels) - b.NewHistogram("prom_adapter_histogram", "desc", []float64{0.1, 1.0}, "k").Record(ctx, 0.5, labels) + c, err := b.NewCounter("prom_adapter_counter_total", "desc", "k") + 
if err != nil { + t.Fatalf("NewCounter error: %v", err) + } + u, err := b.NewUpDownCounter("prom_adapter_updown", "desc", "k") + if err != nil { + t.Fatalf("NewUpDownCounter error: %v", err) + } + ig, err := b.NewInt64Gauge("prom_adapter_int_gauge", "desc", "k") + if err != nil { + t.Fatalf("NewInt64Gauge error: %v", err) + } + fg, err := b.NewFloat64Gauge("prom_adapter_float_gauge", "desc", "k") + if err != nil { + t.Fatalf("NewFloat64Gauge error: %v", err) + } + h, err := b.NewHistogram("prom_adapter_histogram", "desc", []float64{0.1, 1.0}, "k") + if err != nil { + t.Fatalf("NewHistogram error: %v", err) + } + + c.Add(ctx, 1, labels) + u.Add(ctx, 2, labels) + ig.Record(ctx, 99, labels) + fg.Record(ctx, 1.23, labels) + h.Record(ctx, 0.5, labels) if b.HTTPHandler() == nil { t.Error("prometheus adapter HTTPHandler should not be nil") @@ -172,9 +205,20 @@ func TestPrometheusAdapterAllInstruments(t *testing.T) { } func TestOtelAdapterAllInstruments(t *testing.T) { + if os.Getenv("SKIP_OTEL_INTEGRATION") != "" { + t.Skip("skipping OTel integration test because SKIP_OTEL_INTEGRATION is set") + } + + dialTimeout := 300 * time.Millisecond + conn, err := net.DialTimeout("tcp", otelGRPCEndpoint, dialTimeout) + if err != nil { + t.Skipf("skipping OTel integration test; collector %s not reachable: %v", otelGRPCEndpoint, err) + } + _ = conn.Close() + b, err := observability.New(observability.MetricsConfig{ Enabled: true, Backend: "otel", - OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, Insecure: true, ExportInterval: 100 * time.Millisecond}, + OTel: observability.OTelConfig{Protocol: "grpc", Endpoint: otelGRPCEndpoint, Insecure: true, ExportInterval: 100 * time.Millisecond, Timeout: 2 * time.Second}, }) if err != nil { t.Fatalf("failed to create otel backend: %v", err) @@ -182,11 +226,32 @@ func TestOtelAdapterAllInstruments(t *testing.T) { ctx := context.Background() labels := observability.Labels{"k": "v"} - 
b.NewCounter("otel_adapter_counter_total", "desc", "k").Add(ctx, 1, labels) - b.NewUpDownCounter("otel_adapter_updown", "desc", "k").Add(ctx, 2, labels) - b.NewInt64Gauge("otel_adapter_int_gauge", "desc", "k").Record(ctx, 99, labels) - b.NewFloat64Gauge("otel_adapter_float_gauge", "desc", "k").Record(ctx, 1.23, labels) - b.NewHistogram("otel_adapter_histogram", "desc", []float64{0.1, 1.0}, "k").Record(ctx, 0.5, labels) + c, err := b.NewCounter("otel_adapter_counter_total", "desc", "k") + if err != nil { + t.Fatalf("NewCounter error: %v", err) + } + u, err := b.NewUpDownCounter("otel_adapter_updown", "desc", "k") + if err != nil { + t.Fatalf("NewUpDownCounter error: %v", err) + } + ig, err := b.NewInt64Gauge("otel_adapter_int_gauge", "desc", "k") + if err != nil { + t.Fatalf("NewInt64Gauge error: %v", err) + } + fg, err := b.NewFloat64Gauge("otel_adapter_float_gauge", "desc", "k") + if err != nil { + t.Fatalf("NewFloat64Gauge error: %v", err) + } + h, err := b.NewHistogram("otel_adapter_histogram", "desc", []float64{0.1, 1.0}, "k") + if err != nil { + t.Fatalf("NewHistogram error: %v", err) + } + + c.Add(ctx, 1, labels) + u.Add(ctx, 2, labels) + ig.Record(ctx, 99, labels) + fg.Record(ctx, 1.23, labels) + h.Record(ctx, 0.5, labels) if b.HTTPHandler() != nil { t.Error("OTel adapter HTTPHandler should be nil") diff --git a/internal/observability/noop.go b/internal/observability/noop.go index 47acc07..cfd8280 100644 --- a/internal/observability/noop.go +++ b/internal/observability/noop.go @@ -13,38 +13,31 @@ type NoopBackend struct{} // Compile-time interface check. 
var _ Backend = (*NoopBackend)(nil) -func (n *NoopBackend) NewCounter(_ string, _ string, _ ...string) Counter { - _ = n - return noopCounter{} +func (n *NoopBackend) NewCounter(_ string, _ string, _ ...string) (Counter, error) { + return noopCounter{}, nil } -func (n *NoopBackend) NewUpDownCounter(_ string, _ string, _ ...string) UpDownCounter { - _ = n - return noopUpDownCounter{} +func (n *NoopBackend) NewUpDownCounter(_ string, _ string, _ ...string) (UpDownCounter, error) { + return noopUpDownCounter{}, nil } -func (n *NoopBackend) NewInt64Gauge(_ string, _ string, _ ...string) Int64Gauge { - _ = n - return noopInt64Gauge{} +func (n *NoopBackend) NewInt64Gauge(_ string, _ string, _ ...string) (Int64Gauge, error) { + return noopInt64Gauge{}, nil } -func (n *NoopBackend) NewFloat64Gauge(_ string, _ string, _ ...string) Float64Gauge { - _ = n - return noopFloat64Gauge{} +func (n *NoopBackend) NewFloat64Gauge(_ string, _ string, _ ...string) (Float64Gauge, error) { + return noopFloat64Gauge{}, nil } -func (n *NoopBackend) NewHistogram(_ string, _ string, _ []float64, _ ...string) Histogram { - _ = n - return noopHistogram{} +func (n *NoopBackend) NewHistogram(_ string, _ string, _ []float64, _ ...string) (Histogram, error) { + return noopHistogram{}, nil } func (n *NoopBackend) HTTPHandler() http.Handler { - _ = n return nil } func (n *NoopBackend) Shutdown(_ context.Context) error { - _ = n return nil } diff --git a/internal/observability/noop_test.go b/internal/observability/noop_test.go index 9496a0a..037ceb2 100644 --- a/internal/observability/noop_test.go +++ b/internal/observability/noop_test.go @@ -13,32 +13,32 @@ func TestNoopBackendAllInstruments(t *testing.T) { ctx := context.Background() labels := observability.Labels{"k": "v"} - t.Run("Counter", func(_ *testing.T) { - c := n.NewCounter("test_counter", "desc") + t.Run("Counter", func(t *testing.T) { + c, _ := n.NewCounter("test_counter", "desc") c.Add(ctx, 1, labels) c.Add(ctx, 0, nil) }) - 
t.Run("UpDownCounter", func(_ *testing.T) { - u := n.NewUpDownCounter("test_updown", "desc") + t.Run("UpDownCounter", func(t *testing.T) { + u, _ := n.NewUpDownCounter("test_updown", "desc") u.Add(ctx, 1, labels) u.Add(ctx, -1, nil) }) - t.Run("Int64Gauge", func(_ *testing.T) { - g := n.NewInt64Gauge("test_int64gauge", "desc") + t.Run("Int64Gauge", func(t *testing.T) { + g, _ := n.NewInt64Gauge("test_int64gauge", "desc") g.Record(ctx, 42, labels) g.Record(ctx, 0, nil) }) - t.Run("Float64Gauge", func(_ *testing.T) { - g := n.NewFloat64Gauge("test_float64gauge", "desc") + t.Run("Float64Gauge", func(t *testing.T) { + g, _ := n.NewFloat64Gauge("test_float64gauge", "desc") g.Record(ctx, 3.14, labels) g.Record(ctx, 0, nil) }) - t.Run("Histogram", func(_ *testing.T) { - h := n.NewHistogram("test_histogram", "desc", []float64{1, 5, 10}) + t.Run("Histogram", func(t *testing.T) { + h, _ := n.NewHistogram("test_histogram", "desc", []float64{1, 5, 10}) h.Record(ctx, 2.5, labels) h.Record(ctx, 0, nil) }) @@ -56,12 +56,47 @@ func TestNoopBackendAllInstruments(t *testing.T) { }) } -func TestNoopBackendLabelNames(_ *testing.T) { +func TestNoopBackendLabelNames(t *testing.T) { // Verify that label names passed at creation time are accepted without panic. 
n := &observability.NoopBackend{} - n.NewCounter("c", "d", "label1", "label2") - n.NewUpDownCounter("u", "d", "l1") - n.NewInt64Gauge("g1", "d", "l1", "l2", "l3") - n.NewFloat64Gauge("g2", "d") - n.NewHistogram("h", "d", []float64{0.1, 1.0}, "l1") + + assertNoPanic := func(t *testing.T, constructor string, fn func()) { + t.Helper() + defer func() { + if r := recover(); r != nil { + t.Fatalf("%s panicked: %v", constructor, r) + } + }() + fn() + } + + t.Run("NewCounter", func(t *testing.T) { + assertNoPanic(t, "NewCounter", func() { + _, _ = n.NewCounter("c", "d", "label1", "label2") + }) + }) + + t.Run("NewUpDownCounter", func(t *testing.T) { + assertNoPanic(t, "NewUpDownCounter", func() { + _, _ = n.NewUpDownCounter("u", "d", "l1") + }) + }) + + t.Run("NewInt64Gauge", func(t *testing.T) { + assertNoPanic(t, "NewInt64Gauge", func() { + _, _ = n.NewInt64Gauge("g1", "d", "l1", "l2", "l3") + }) + }) + + t.Run("NewFloat64Gauge", func(t *testing.T) { + assertNoPanic(t, "NewFloat64Gauge", func() { + _, _ = n.NewFloat64Gauge("g2", "d") + }) + }) + + t.Run("NewHistogram", func(t *testing.T) { + assertNoPanic(t, "NewHistogram", func() { + _, _ = n.NewHistogram("h", "d", []float64{0.1, 1.0}, "l1") + }) + }) } diff --git a/internal/observability/otel/backend.go b/internal/observability/otel/backend.go index d3e3a23..7f49579 100644 --- a/internal/observability/otel/backend.go +++ b/internal/observability/otel/backend.go @@ -9,7 +9,10 @@ package otel import ( "context" "fmt" + "log" "net/http" + "regexp" + "strings" "time" "go.opentelemetry.io/otel/attribute" @@ -17,6 +20,8 @@ import ( sdkmetric "go.opentelemetry.io/otel/sdk/metric" ) +var metricLabelNameRE = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`) + // Config holds OTel backend configuration. type Config struct { // Protocol is "grpc" (default) or "http". @@ -31,6 +36,9 @@ type Config struct { // ExportInterval is the period between pushes to the collector. 
ExportInterval time.Duration + // Timeout bounds exporter construction calls. + Timeout time.Duration + ServiceName string ServiceVersion string DeploymentEnvironment string @@ -57,9 +65,15 @@ func New(cfg Config) (*Backend, error) { if cfg.Protocol == "" { cfg.Protocol = "grpc" } + if strings.TrimSpace(cfg.Endpoint) == "" { + return nil, fmt.Errorf("otel backend: empty cfg.Endpoint") + } if cfg.ExportInterval <= 0 { cfg.ExportInterval = 60 * time.Second } + if cfg.Timeout <= 0 { + cfg.Timeout = 10 * time.Second + } if cfg.ServiceName == "" { cfg.ServiceName = "gerbil" } @@ -100,111 +114,196 @@ func (b *Backend) Shutdown(ctx context.Context) error { } // NewCounter creates an OTel Int64Counter. -func (b *Backend) NewCounter(name, desc string, _ ...string) *Counter { +func (b *Backend) NewCounter(name, desc string, labelNames ...string) (*Counter, error) { + normalizedLabelNames, err := validateLabelNames(labelNames) + if err != nil { + return nil, fmt.Errorf("otel: create counter %q: %w", name, err) + } c, err := b.meter.Int64Counter(name, metric.WithDescription(desc)) if err != nil { - panic(fmt.Sprintf("otel: create counter %q: %v", name, err)) + return nil, fmt.Errorf("otel: create counter %q: %w", name, err) } - return &Counter{c: c} + return &Counter{c: c, labelNames: normalizedLabelNames}, nil } // NewUpDownCounter creates an OTel Int64UpDownCounter. 
-func (b *Backend) NewUpDownCounter(name, desc string, _ ...string) *UpDownCounter { +func (b *Backend) NewUpDownCounter(name, desc string, labelNames ...string) (*UpDownCounter, error) { + normalizedLabelNames, err := validateLabelNames(labelNames) + if err != nil { + return nil, fmt.Errorf("otel: create up-down counter %q: %w", name, err) + } c, err := b.meter.Int64UpDownCounter(name, metric.WithDescription(desc)) if err != nil { - panic(fmt.Sprintf("otel: create up-down counter %q: %v", name, err)) + return nil, fmt.Errorf("otel: create up-down counter %q: %w", name, err) } - return &UpDownCounter{c: c} + return &UpDownCounter{c: c, labelNames: normalizedLabelNames}, nil } // NewInt64Gauge creates an OTel Int64Gauge. -func (b *Backend) NewInt64Gauge(name, desc string, _ ...string) *Int64Gauge { +func (b *Backend) NewInt64Gauge(name, desc string, labelNames ...string) (*Int64Gauge, error) { + normalizedLabelNames, err := validateLabelNames(labelNames) + if err != nil { + return nil, fmt.Errorf("otel: create int64 gauge %q: %w", name, err) + } g, err := b.meter.Int64Gauge(name, metric.WithDescription(desc)) if err != nil { - panic(fmt.Sprintf("otel: create int64 gauge %q: %v", name, err)) + return nil, fmt.Errorf("otel: create int64 gauge %q: %w", name, err) } - return &Int64Gauge{g: g} + return &Int64Gauge{g: g, labelNames: normalizedLabelNames}, nil } // NewFloat64Gauge creates an OTel Float64Gauge. 
-func (b *Backend) NewFloat64Gauge(name, desc string, _ ...string) *Float64Gauge { +func (b *Backend) NewFloat64Gauge(name, desc string, labelNames ...string) (*Float64Gauge, error) { + normalizedLabelNames, err := validateLabelNames(labelNames) + if err != nil { + return nil, fmt.Errorf("otel: create float64 gauge %q: %w", name, err) + } g, err := b.meter.Float64Gauge(name, metric.WithDescription(desc)) if err != nil { - panic(fmt.Sprintf("otel: create float64 gauge %q: %v", name, err)) + return nil, fmt.Errorf("otel: create float64 gauge %q: %w", name, err) } - return &Float64Gauge{g: g} + return &Float64Gauge{g: g, labelNames: normalizedLabelNames}, nil } // NewHistogram creates an OTel Float64Histogram with explicit bucket boundaries. -func (b *Backend) NewHistogram(name, desc string, buckets []float64, _ ...string) *Histogram { +func (b *Backend) NewHistogram(name, desc string, buckets []float64, labelNames ...string) (*Histogram, error) { + normalizedLabelNames, err := validateLabelNames(labelNames) + if err != nil { + return nil, fmt.Errorf("otel: create histogram %q: %w", name, err) + } h, err := b.meter.Float64Histogram(name, metric.WithDescription(desc), metric.WithExplicitBucketBoundaries(buckets...), ) if err != nil { - panic(fmt.Sprintf("otel: create histogram %q: %v", name, err)) + return nil, fmt.Errorf("otel: create histogram %q: %w", name, err) } - return &Histogram{h: h} + return &Histogram{h: h, labelNames: normalizedLabelNames}, nil } -// labelsToAttrs converts a Labels map to OTel attribute key-value pairs. 
-func labelsToAttrs(labels map[string]string) []attribute.KeyValue { - if len(labels) == 0 { - return nil +func validateLabelNames(labelNames []string) ([]string, error) { + if len(labelNames) == 0 { + return nil, nil } - attrs := make([]attribute.KeyValue, 0, len(labels)) - for k, v := range labels { - attrs = append(attrs, attribute.String(k, v)) + + normalized := make([]string, len(labelNames)) + seen := make(map[string]struct{}, len(labelNames)) + for i, name := range labelNames { + if !metricLabelNameRE.MatchString(name) { + return nil, fmt.Errorf("invalid label name %q", name) + } + if _, exists := seen[name]; exists { + return nil, fmt.Errorf("duplicate label name %q", name) + } + seen[name] = struct{}{} + normalized[i] = name } + + return normalized, nil +} + +func labelsToAttrs(labelNames []string, labels map[string]string) []attribute.KeyValue { + if len(labelNames) == 0 { + if len(labels) > 0 { + log.Printf("WARN: dropping otel metric sample due to unexpected labels: got=%v expected=none", labels) + return nil + } + return []attribute.KeyValue{} + } + + attrs := make([]attribute.KeyValue, 0, len(labelNames)) + for _, labelName := range labelNames { + attrs = append(attrs, attribute.String(labelName, labels[labelName])) + } + + for got := range labels { + found := false + for _, expected := range labelNames { + if got == expected { + found = true + break + } + } + if !found { + log.Printf("WARN: dropping otel metric sample due to unexpected label key %q (expected=%v)", got, labelNames) + return nil + } + } + return attrs } // Counter wraps an OTel Int64Counter. type Counter struct { - c metric.Int64Counter + c metric.Int64Counter + labelNames []string } // Add increments the counter by value. 
func (c *Counter) Add(ctx context.Context, value int64, labels map[string]string) { - c.c.Add(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) + attrs := labelsToAttrs(c.labelNames, labels) + if attrs == nil { + return + } + c.c.Add(ctx, value, metric.WithAttributes(attrs...)) } // UpDownCounter wraps an OTel Int64UpDownCounter. type UpDownCounter struct { - c metric.Int64UpDownCounter + c metric.Int64UpDownCounter + labelNames []string } // Add adjusts the up-down counter by value. func (u *UpDownCounter) Add(ctx context.Context, value int64, labels map[string]string) { - u.c.Add(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) + attrs := labelsToAttrs(u.labelNames, labels) + if attrs == nil { + return + } + u.c.Add(ctx, value, metric.WithAttributes(attrs...)) } // Int64Gauge wraps an OTel Int64Gauge. type Int64Gauge struct { - g metric.Int64Gauge + g metric.Int64Gauge + labelNames []string } // Record sets the gauge to value. func (g *Int64Gauge) Record(ctx context.Context, value int64, labels map[string]string) { - g.g.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) + attrs := labelsToAttrs(g.labelNames, labels) + if attrs == nil { + return + } + g.g.Record(ctx, value, metric.WithAttributes(attrs...)) } // Float64Gauge wraps an OTel Float64Gauge. type Float64Gauge struct { - g metric.Float64Gauge + g metric.Float64Gauge + labelNames []string } // Record sets the gauge to value. func (g *Float64Gauge) Record(ctx context.Context, value float64, labels map[string]string) { - g.g.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) + attrs := labelsToAttrs(g.labelNames, labels) + if attrs == nil { + return + } + g.g.Record(ctx, value, metric.WithAttributes(attrs...)) } // Histogram wraps an OTel Float64Histogram. type Histogram struct { - h metric.Float64Histogram + h metric.Float64Histogram + labelNames []string } // Record observes value in the histogram. 
func (h *Histogram) Record(ctx context.Context, value float64, labels map[string]string) { - h.h.Record(ctx, value, metric.WithAttributes(labelsToAttrs(labels)...)) + attrs := labelsToAttrs(h.labelNames, labels) + if attrs == nil { + return + } + h.h.Record(ctx, value, metric.WithAttributes(attrs...)) } diff --git a/internal/observability/otel/backend_test.go b/internal/observability/otel/backend_test.go index e527678..ef753e8 100644 --- a/internal/observability/otel/backend_test.go +++ b/internal/observability/otel/backend_test.go @@ -55,7 +55,10 @@ func TestOtelBackendCounter(t *testing.T) { b := newInMemoryBackend(t) defer b.Shutdown(context.Background()) //nolint:errcheck - c := b.NewCounter("gerbil_test_counter_total", "test counter", "result") + c, err := b.NewCounter("gerbil_test_counter_total", "test counter", "result") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } // Should not panic c.Add(context.Background(), 1, map[string]string{"result": "ok"}) c.Add(context.Background(), 5, nil) @@ -65,7 +68,10 @@ func TestOtelBackendUpDownCounter(t *testing.T) { b := newInMemoryBackend(t) defer b.Shutdown(context.Background()) //nolint:errcheck - u := b.NewUpDownCounter("gerbil_test_updown", "test updown", "state") + u, err := b.NewUpDownCounter("gerbil_test_updown", "test updown", "state") + if err != nil { + t.Fatalf("NewUpDownCounter returned error: %v", err) + } u.Add(context.Background(), 3, map[string]string{"state": "active"}) u.Add(context.Background(), -1, map[string]string{"state": "active"}) } @@ -74,7 +80,10 @@ func TestOtelBackendInt64Gauge(t *testing.T) { b := newInMemoryBackend(t) defer b.Shutdown(context.Background()) //nolint:errcheck - g := b.NewInt64Gauge("gerbil_test_int_gauge", "test gauge") + g, err := b.NewInt64Gauge("gerbil_test_int_gauge", "test gauge") + if err != nil { + t.Fatalf("NewInt64Gauge returned error: %v", err) + } g.Record(context.Background(), 42, nil) } @@ -82,7 +91,10 @@ func 
TestOtelBackendFloat64Gauge(t *testing.T) { b := newInMemoryBackend(t) defer b.Shutdown(context.Background()) //nolint:errcheck - g := b.NewFloat64Gauge("gerbil_test_float_gauge", "test float gauge") + g, err := b.NewFloat64Gauge("gerbil_test_float_gauge", "test float gauge") + if err != nil { + t.Fatalf("NewFloat64Gauge returned error: %v", err) + } g.Record(context.Background(), 3.14, nil) } @@ -90,8 +102,11 @@ func TestOtelBackendHistogram(t *testing.T) { b := newInMemoryBackend(t) defer b.Shutdown(context.Background()) //nolint:errcheck - h := b.NewHistogram("gerbil_test_duration_seconds", "test histogram", + h, err := b.NewHistogram("gerbil_test_duration_seconds", "test histogram", []float64{0.1, 0.5, 1.0}, "method") + if err != nil { + t.Fatalf("NewHistogram returned error: %v", err) + } h.Record(context.Background(), 0.3, map[string]string{"method": "GET"}) } @@ -139,3 +154,22 @@ func TestOtelBackendDeploymentEnvironment(t *testing.T) { } defer b.Shutdown(context.Background()) //nolint:errcheck } + +func TestOtelBackendRejectsInvalidLabelNames(t *testing.T) { + b := newInMemoryBackend(t) + defer b.Shutdown(context.Background()) //nolint:errcheck + + t.Run("duplicate labels", func(t *testing.T) { + _, err := b.NewCounter("gerbil_test_invalid_labels_total", "test counter", "result", "result") + if err == nil { + t.Fatal("expected error for duplicate label names") + } + }) + + t.Run("invalid label name", func(t *testing.T) { + _, err := b.NewHistogram("gerbil_test_invalid_histogram", "test histogram", []float64{0.1, 1.0}, "status-code") + if err == nil { + t.Fatal("expected error for invalid label name") + } + }) +} diff --git a/internal/observability/otel/exporter.go b/internal/observability/otel/exporter.go index 44fe1e2..89950fa 100644 --- a/internal/observability/otel/exporter.go +++ b/internal/observability/otel/exporter.go @@ -3,6 +3,8 @@ package otel import ( "context" "fmt" + "net/url" + "strings" 
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc" "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp" @@ -11,6 +13,10 @@ import ( // newExporter creates the appropriate OTLP exporter based on cfg.Protocol. func newExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error) { + if strings.TrimSpace(cfg.Endpoint) == "" { + return nil, fmt.Errorf("otel: cfg.Endpoint is empty") + } + switch cfg.Protocol { case "grpc", "": return newGRPCExporter(ctx, cfg) @@ -36,8 +42,20 @@ func newGRPCExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error } func newHTTPExporter(ctx context.Context, cfg Config) (sdkmetric.Exporter, error) { - opts := []otlpmetrichttp.Option{ - otlpmetrichttp.WithEndpoint(cfg.Endpoint), + endpoint := strings.TrimSpace(cfg.Endpoint) + + opts := make([]otlpmetrichttp.Option, 0, 3) + if strings.Contains(endpoint, "://") { + parsed, err := url.Parse(endpoint) + if err != nil { + return nil, fmt.Errorf("otlp http exporter: parse endpoint URL %q: %w", endpoint, err) + } + opts = append(opts, otlpmetrichttp.WithEndpointURL(parsed.String())) + } else { + opts = append(opts, + otlpmetrichttp.WithEndpoint(endpoint), + otlpmetrichttp.WithURLPath("/v1/metrics"), + ) } if cfg.Insecure { opts = append(opts, otlpmetrichttp.WithInsecure()) From 191b4fa26a390b3d5651ed8fe7eb9b5dce76174b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:24 +0200 Subject: [PATCH 22/25] feat(prometheus): robust label validation and registration handling --- internal/observability/prometheus/backend.go | 183 +++++++++++++++--- .../observability/prometheus/backend_test.go | 78 +++++++- 2 files changed, 222 insertions(+), 39 deletions(-) diff --git a/internal/observability/prometheus/backend.go b/internal/observability/prometheus/backend.go index f2744c1..a513b88 100644 --- a/internal/observability/prometheus/backend.go +++ b/internal/observability/prometheus/backend.go @@ -7,6 +7,7 @@ package 
prometheus import ( "context" + "log" "net/http" "github.com/prometheus/client_golang/prometheus" @@ -30,9 +31,10 @@ type Config struct { // in the backend-specific instrument types that implement the observability // instrument interfaces. type Backend struct { - cfg Config - registry *prometheus.Registry - handler http.Handler + cfg Config + registry *prometheus.Registry + handler http.Handler + droppedSamplesCounter prometheus.Counter } // New creates and initialises a Prometheus backend. @@ -48,6 +50,11 @@ func New(cfg Config) (*Backend, error) { } registry := prometheus.NewRegistry() + droppedSamplesCounter := prometheus.NewCounter(prometheus.CounterOpts{ + Name: "gerbil_dropped_metric_samples_total", + Help: "Total number of metric samples dropped due to invalid labels or unsupported label sets", + }) + registry.MustRegister(droppedSamplesCounter) // Include Go and process metrics by default. includeGo := cfg.IncludeGoMetrics == nil || *cfg.IncludeGoMetrics @@ -62,7 +69,7 @@ func New(cfg Config) (*Backend, error) { EnableOpenMetrics: false, }) - return &Backend{cfg: cfg, registry: registry, handler: handler}, nil + return &Backend{cfg: cfg, registry: registry, handler: handler, droppedSamplesCounter: droppedSamplesCounter}, nil } // HTTPHandler returns the Prometheus /metrics HTTP handler. @@ -78,60 +85,107 @@ func (b *Backend) Shutdown(_ context.Context) error { } // NewCounter creates a Prometheus CounterVec registered on the backend's registry. 
-func (b *Backend) NewCounter(name, desc string, labelNames ...string) *Counter { +func (b *Backend) NewCounter(name, desc string, labelNames ...string) (*Counter, error) { vec := prometheus.NewCounterVec(prometheus.CounterOpts{ Name: name, Help: desc, }, labelNames) - b.registry.MustRegister(vec) - return &Counter{vec: vec} + if err := b.registry.Register(vec); err != nil { + if are, ok := err.(prometheus.AlreadyRegisteredError); ok { + existing, ok := are.ExistingCollector.(*prometheus.CounterVec) + if !ok { + return nil, err + } + return &Counter{vec: existing, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil + } + return nil, err + } + return &Counter{vec: vec, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil } // NewUpDownCounter creates a Prometheus GaugeVec (Prometheus gauges are // bidirectional) registered on the backend's registry. -func (b *Backend) NewUpDownCounter(name, desc string, labelNames ...string) *UpDownCounter { +func (b *Backend) NewUpDownCounter(name, desc string, labelNames ...string) (*UpDownCounter, error) { vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: name, Help: desc, }, labelNames) - b.registry.MustRegister(vec) - return &UpDownCounter{vec: vec} + if err := b.registry.Register(vec); err != nil { + if are, ok := err.(prometheus.AlreadyRegisteredError); ok { + existing, ok := are.ExistingCollector.(*prometheus.GaugeVec) + if !ok { + return nil, err + } + return &UpDownCounter{vec: existing, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil + } + return nil, err + } + return &UpDownCounter{vec: vec, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil } // NewInt64Gauge creates a Prometheus GaugeVec registered on the backend's registry. 
-func (b *Backend) NewInt64Gauge(name, desc string, labelNames ...string) *Int64Gauge { +func (b *Backend) NewInt64Gauge(name, desc string, labelNames ...string) (*Int64Gauge, error) { vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: name, Help: desc, }, labelNames) - b.registry.MustRegister(vec) - return &Int64Gauge{vec: vec} + if err := b.registry.Register(vec); err != nil { + if are, ok := err.(prometheus.AlreadyRegisteredError); ok { + existing, ok := are.ExistingCollector.(*prometheus.GaugeVec) + if !ok { + return nil, err + } + return &Int64Gauge{vec: existing, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil + } + return nil, err + } + return &Int64Gauge{vec: vec, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil } // NewFloat64Gauge creates a Prometheus GaugeVec registered on the backend's registry. -func (b *Backend) NewFloat64Gauge(name, desc string, labelNames ...string) *Float64Gauge { +func (b *Backend) NewFloat64Gauge(name, desc string, labelNames ...string) (*Float64Gauge, error) { vec := prometheus.NewGaugeVec(prometheus.GaugeOpts{ Name: name, Help: desc, }, labelNames) - b.registry.MustRegister(vec) - return &Float64Gauge{vec: vec} + if err := b.registry.Register(vec); err != nil { + if are, ok := err.(prometheus.AlreadyRegisteredError); ok { + existing, ok := are.ExistingCollector.(*prometheus.GaugeVec) + if !ok { + return nil, err + } + return &Float64Gauge{vec: existing, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil + } + return nil, err + } + return &Float64Gauge{vec: vec, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil } // NewHistogram creates a Prometheus HistogramVec registered on the backend's registry. 
-func (b *Backend) NewHistogram(name, desc string, buckets []float64, labelNames ...string) *Histogram { +func (b *Backend) NewHistogram(name, desc string, buckets []float64, labelNames ...string) (*Histogram, error) { vec := prometheus.NewHistogramVec(prometheus.HistogramOpts{ Name: name, Help: desc, Buckets: buckets, }, labelNames) - b.registry.MustRegister(vec) - return &Histogram{vec: vec} + if err := b.registry.Register(vec); err != nil { + if are, ok := err.(prometheus.AlreadyRegisteredError); ok { + existing, ok := are.ExistingCollector.(*prometheus.HistogramVec) + if !ok { + return nil, err + } + return &Histogram{vec: existing, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil + } + return nil, err + } + return &Histogram{vec: vec, labelNames: append([]string(nil), labelNames...), droppedSamplesCounter: b.droppedSamplesCounter}, nil } // Counter is a native Prometheus counter instrument. type Counter struct { - vec *prometheus.CounterVec + vec *prometheus.CounterVec + labelNames []string + droppedSamplesCounter prometheus.Counter } // Add increments the counter by value for the given labels. @@ -139,47 +193,118 @@ type Counter struct { // value must be non-negative. Negative values are ignored. func (c *Counter) Add(_ context.Context, value int64, labels map[string]string) { if value < 0 { + log.Printf("WARN: counter add called with negative value=%d labels=%v expected_labels=%v", value, labels, c.labelNames) return } - c.vec.With(prometheus.Labels(labels)).Add(float64(value)) + normalized, ok := normalizeLabels(c.labelNames, labels, c.droppedSamplesCounter) + if !ok { + return + } + defer guardMetricPanic("counter", c.labelNames, labels) + c.vec.With(normalized).Add(float64(value)) } // UpDownCounter is a native Prometheus gauge used as a bidirectional counter. 
type UpDownCounter struct { - vec *prometheus.GaugeVec + vec *prometheus.GaugeVec + labelNames []string + droppedSamplesCounter prometheus.Counter } // Add adjusts the gauge by value for the given labels. func (u *UpDownCounter) Add(_ context.Context, value int64, labels map[string]string) { - u.vec.With(prometheus.Labels(labels)).Add(float64(value)) + normalized, ok := normalizeLabels(u.labelNames, labels, u.droppedSamplesCounter) + if !ok { + return + } + defer guardMetricPanic("updown", u.labelNames, labels) + u.vec.With(normalized).Add(float64(value)) } // Int64Gauge is a native Prometheus gauge recording integer snapshot values. type Int64Gauge struct { - vec *prometheus.GaugeVec + vec *prometheus.GaugeVec + labelNames []string + droppedSamplesCounter prometheus.Counter } // Record sets the gauge to value for the given labels. func (g *Int64Gauge) Record(_ context.Context, value int64, labels map[string]string) { - g.vec.With(prometheus.Labels(labels)).Set(float64(value)) + normalized, ok := normalizeLabels(g.labelNames, labels, g.droppedSamplesCounter) + if !ok { + return + } + defer guardMetricPanic("int64-gauge", g.labelNames, labels) + g.vec.With(normalized).Set(float64(value)) } // Float64Gauge is a native Prometheus gauge recording float snapshot values. type Float64Gauge struct { - vec *prometheus.GaugeVec + vec *prometheus.GaugeVec + labelNames []string + droppedSamplesCounter prometheus.Counter } // Record sets the gauge to value for the given labels. func (g *Float64Gauge) Record(_ context.Context, value float64, labels map[string]string) { - g.vec.With(prometheus.Labels(labels)).Set(value) + normalized, ok := normalizeLabels(g.labelNames, labels, g.droppedSamplesCounter) + if !ok { + return + } + defer guardMetricPanic("float64-gauge", g.labelNames, labels) + g.vec.With(normalized).Set(value) } // Histogram is a native Prometheus histogram instrument. 
type Histogram struct { - vec *prometheus.HistogramVec + vec *prometheus.HistogramVec + labelNames []string + droppedSamplesCounter prometheus.Counter } // Record observes value for the given labels. func (h *Histogram) Record(_ context.Context, value float64, labels map[string]string) { - h.vec.With(prometheus.Labels(labels)).Observe(value) + normalized, ok := normalizeLabels(h.labelNames, labels, h.droppedSamplesCounter) + if !ok { + return + } + defer guardMetricPanic("histogram", h.labelNames, labels) + h.vec.With(normalized).Observe(value) +} + +func normalizeLabels(labelNames []string, labels map[string]string, droppedSamplesCounter prometheus.Counter) (prometheus.Labels, bool) { + if len(labelNames) == 0 { + if len(labels) > 0 { + if droppedSamplesCounter != nil { + droppedSamplesCounter.Inc() + } + log.Printf("WARN: dropping metric sample due to unexpected labels: got=%v expected=none", labels) + return nil, false + } + return nil, true + } + + normalized := make(prometheus.Labels, len(labelNames)) + for _, name := range labelNames { + normalized[name] = "" + } + + for k, v := range labels { + if _, ok := normalized[k]; !ok { + if droppedSamplesCounter != nil { + droppedSamplesCounter.Inc() + } + log.Printf("WARN: dropping metric sample due to unexpected label key %q (expected=%v)", k, labelNames) + return nil, false + } + normalized[k] = v + } + + return normalized, true +} + +func guardMetricPanic(kind string, expected []string, labels map[string]string) { + if recovered := recover(); recovered != nil { + log.Printf("WARN: dropped %s metric sample due to label panic: expected=%v got=%v err=%v", kind, expected, labels, recovered) + } } diff --git a/internal/observability/prometheus/backend_test.go b/internal/observability/prometheus/backend_test.go index d60821f..23f4b90 100644 --- a/internal/observability/prometheus/backend_test.go +++ b/internal/observability/prometheus/backend_test.go @@ -36,7 +36,10 @@ func TestPrometheusBackendShutdown(t *testing.T) { 
func TestPrometheusBackendCounter(t *testing.T) { b := newTestBackend(t) - c := b.NewCounter("test_counter_total", "A test counter", "result") + c, err := b.NewCounter("test_counter_total", "A test counter", "result") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } c.Add(context.Background(), 3, map[string]string{"result": "ok"}) body := scrapeMetrics(t, b) @@ -45,7 +48,10 @@ func TestPrometheusBackendCounter(t *testing.T) { func TestPrometheusBackendUpDownCounter(t *testing.T) { b := newTestBackend(t) - u := b.NewUpDownCounter("test_gauge_total", "A test up-down counter", "state") + u, err := b.NewUpDownCounter("test_gauge_total", "A test up-down counter", "state") + if err != nil { + t.Fatalf("NewUpDownCounter returned error: %v", err) + } u.Add(context.Background(), 5, map[string]string{"state": "active"}) u.Add(context.Background(), -2, map[string]string{"state": "active"}) @@ -55,7 +61,10 @@ func TestPrometheusBackendUpDownCounter(t *testing.T) { func TestPrometheusBackendInt64Gauge(t *testing.T) { b := newTestBackend(t) - g := b.NewInt64Gauge("test_int_gauge", "An integer gauge", "ifname") + g, err := b.NewInt64Gauge("test_int_gauge", "An integer gauge", "ifname") + if err != nil { + t.Fatalf("NewInt64Gauge returned error: %v", err) + } g.Record(context.Background(), 42, map[string]string{"ifname": "wg0"}) body := scrapeMetrics(t, b) @@ -64,7 +73,10 @@ func TestPrometheusBackendInt64Gauge(t *testing.T) { func TestPrometheusBackendFloat64Gauge(t *testing.T) { b := newTestBackend(t) - g := b.NewFloat64Gauge("test_float_gauge", "A float gauge", "cert") + g, err := b.NewFloat64Gauge("test_float_gauge", "A float gauge", "cert") + if err != nil { + t.Fatalf("NewFloat64Gauge returned error: %v", err) + } g.Record(context.Background(), 7.5, map[string]string{"cert": "example.com"}) body := scrapeMetrics(t, b) @@ -74,7 +86,10 @@ func TestPrometheusBackendFloat64Gauge(t *testing.T) { func TestPrometheusBackendHistogram(t *testing.T) { b := 
newTestBackend(t) buckets := []float64{0.1, 0.5, 1.0, 5.0} - h := b.NewHistogram("test_duration_seconds", "A test histogram", buckets, "method") + h, err := b.NewHistogram("test_duration_seconds", "A test histogram", buckets, "method") + if err != nil { + t.Fatalf("NewHistogram returned error: %v", err) + } h.Record(context.Background(), 0.3, map[string]string{"method": "GET"}) body := scrapeMetrics(t, b) @@ -85,7 +100,10 @@ func TestPrometheusBackendHistogram(t *testing.T) { func TestPrometheusBackendMultipleLabels(t *testing.T) { b := newTestBackend(t) - c := b.NewCounter("multi_label_total", "Multi-label counter", "method", "route", "status_code") + c, err := b.NewCounter("multi_label_total", "Multi-label counter", "method", "route", "status_code") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } c.Add(context.Background(), 1, map[string]string{ "method": "POST", "route": "/api/peers", @@ -122,23 +140,29 @@ func TestPrometheusBackendNoGoMetrics(t *testing.T) { func TestPrometheusBackendNilLabels(t *testing.T) { // Adding with nil labels should not panic (treated as empty map). 
b := newTestBackend(t) - c := b.NewCounter("nil_labels_total", "counter with no labels") + c, err := b.NewCounter("nil_labels_total", "counter with no labels") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } // nil labels with no label names declared should be safe c.Add(context.Background(), 1, nil) } func TestPrometheusBackendConcurrentAdd(t *testing.T) { b := newTestBackend(t) - c := b.NewCounter("concurrent_total", "concurrent counter", "worker") + c, err := b.NewCounter("concurrent_total", "concurrent counter", "worker") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } done := make(chan struct{}) for i := 0; i < 10; i++ { - go func(_ int) { + go func() { for j := 0; j < 100; j++ { c.Add(context.Background(), 1, map[string]string{"worker": "w"}) } done <- struct{}{} - }(i) + }() } for i := 0; i < 10; i++ { <-done @@ -148,6 +172,40 @@ func TestPrometheusBackendConcurrentAdd(t *testing.T) { assertMetricPresent(t, body, `concurrent_total{worker="w"} 1000`) } +func TestPrometheusBackendAlreadyRegisteredCounter(t *testing.T) { + b := newTestBackend(t) + c1, err := b.NewCounter("dupe_counter_total", "duplicate counter", "result") + if err != nil { + t.Fatalf("first NewCounter returned error: %v", err) + } + c2, err := b.NewCounter("dupe_counter_total", "duplicate counter", "result") + if err != nil { + t.Fatalf("second NewCounter returned error: %v", err) + } + + c1.Add(context.Background(), 1, map[string]string{"result": "ok"}) + c2.Add(context.Background(), 2, map[string]string{"result": "ok"}) + + body := scrapeMetrics(t, b) + assertMetricPresent(t, body, `dupe_counter_total{result="ok"} 3`) +} + +func TestPrometheusBackendInvalidLabelsNoPanic(t *testing.T) { + b := newTestBackend(t) + c, err := b.NewCounter("invalid_labels_total", "invalid labels test", "result") + if err != nil { + t.Fatalf("NewCounter returned error: %v", err) + } + + // Extra label key should be dropped and must not panic. 
+ c.Add(context.Background(), 5, map[string]string{"result": "ok", "unexpected": "x"}) + + body := scrapeMetrics(t, b) + if strings.Contains(body, `invalid_labels_total{result="ok"}`) { + t.Error("invalid label sample should have been dropped") + } +} + // --- helpers --- func scrapeMetrics(t *testing.T, b *obsprom.Backend) string { From cda6fa677295aa48ce1781092675b78c35b793b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:12:24 +0200 Subject: [PATCH 23/25] feat(cli/proxy): add OTLP timeout flag and make proxy metrics resilient --- main.go | 17 ++++++++++++++++- proxy/proxy.go | 20 ++++++++++---------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/main.go b/main.go index 62bfe7c..2a31c2b 100644 --- a/main.go +++ b/main.go @@ -175,6 +175,7 @@ func main() { otelMetricsEndpoint string otelMetricsInsecure bool otelMetricsExportInterval time.Duration + otelMetricsTimeout time.Duration ) interfaceName = os.Getenv("INTERFACE") @@ -229,6 +230,14 @@ func main() { log.Printf("WARN: invalid OTEL_METRICS_EXPORT_INTERVAL=%q: %v", v, err2) } } + otelMetricsTimeout = 10 * time.Second // default + if v := os.Getenv("OTEL_METRICS_TIMEOUT"); v != "" { + if d, err2 := time.ParseDuration(v); err2 == nil { + otelMetricsTimeout = d + } else { + log.Printf("WARN: invalid OTEL_METRICS_TIMEOUT=%q: %v", v, err2) + } + } if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") @@ -322,6 +331,7 @@ func main() { flag.StringVar(&otelMetricsEndpoint, "otel-metrics-endpoint", otelMetricsEndpoint, "OTLP collector endpoint (e.g. 
localhost:4317)") flag.BoolVar(&otelMetricsInsecure, "otel-metrics-insecure", otelMetricsInsecure, "Disable TLS for OTLP connection") flag.DurationVar(&otelMetricsExportInterval, "otel-metrics-export-interval", otelMetricsExportInterval, "Interval between OTLP metric pushes") + flag.DurationVar(&otelMetricsTimeout, "otel-metrics-timeout", otelMetricsTimeout, "Timeout for OTLP exporter setup") flag.Parse() @@ -347,6 +357,7 @@ func main() { Endpoint: otelMetricsEndpoint, Insecure: otelMetricsInsecure, ExportInterval: otelMetricsExportInterval, + Timeout: otelMetricsTimeout, }, ServiceName: "gerbil", ServiceVersion: "1.0.0", @@ -543,6 +554,8 @@ func main() { // Register metrics endpoint only for Prometheus backend. // OTel backend pushes to a collector; no /metrics endpoint needed. + // Note: metricsPath is registered directly without httpMetricsMiddleware to prevent infinite recursion. + // The metricsHandler must not be wrapped by the middleware, as it would observe its own observation calls. 
if metricsHandler != nil { http.Handle(metricsPath, metricsHandler) logger.Info("Metrics endpoint enabled at %s", metricsPath) @@ -1162,10 +1175,12 @@ func removePeerInternal(publicKey string) error { // Get current peer info before removing to clear relay connections and bandwidth limits var wgIPs []string + allowedIPsCount := 0 device, err := wgClient.Device(interfaceName) if err == nil { for _, peer := range device.Peers { if peer.PublicKey.String() == publicKey { + allowedIPsCount = len(peer.AllowedIPs) // Extract WireGuard IPs from this peer's allowed IPs for _, allowedIP := range peer.AllowedIPs { wgIPs = append(wgIPs, allowedIP.IP.String()) @@ -1208,7 +1223,7 @@ func removePeerInternal(publicKey string) error { // Record metrics metrics.RecordPeersTotal(interfaceName, -1) - metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(len(wgIPs))) + metrics.RecordAllowedIPsCount(interfaceName, publicKey, -int64(allowedIPsCount)) return nil } diff --git a/proxy/proxy.go b/proxy/proxy.go index 71cf4ed..9b46e10 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -548,7 +548,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { logger.Debug("SNI extraction failed: %v", err) return } - metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds()) + metrics.RecordProxyTLSHandshake(time.Since(clientHelloStart).Seconds()) if hostname == "" { log.Println("No SNI hostname found") @@ -596,8 +596,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { defer targetConn.Close() logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) - metrics.RecordActiveProxyConnection(hostname, 1) - defer metrics.RecordActiveProxyConnection(hostname, -1) + metrics.RecordActiveProxyConnection(1) + defer metrics.RecordActiveProxyConnection(-1) // Send PROXY protocol header if enabled if p.proxyProtocol { @@ -655,7 +655,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check local overrides 
first if _, isOverride := p.localOverrides[hostname]; isOverride { logger.Debug("Local override matched for hostname: %s", hostname) - metrics.RecordProxyRouteLookup("local_override", hostname) + metrics.RecordProxyRouteLookup("local_override") return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -668,7 +668,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { _, isLocal := p.localSNIs[hostname] p.localSNIsLock.RUnlock() if isLocal { - metrics.RecordProxyRouteLookup("local", hostname) + metrics.RecordProxyRouteLookup("local") return &RouteRecord{ Hostname: hostname, TargetHost: p.localProxyAddr, @@ -679,16 +679,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { // Check cache first if cached, found := p.cache.Get(hostname); found { if cached == nil { - metrics.RecordProxyRouteLookup("cached_not_found", hostname) + metrics.RecordProxyRouteLookup("cached_not_found") return nil, nil // Cached negative result } logger.Debug("Cache hit for hostname: %s", hostname) - metrics.RecordProxyRouteLookup("cache_hit", hostname) + metrics.RecordProxyRouteLookup("cache_hit") return cached.(*RouteRecord), nil } logger.Debug("Cache miss for hostname: %s, querying API", hostname) - metrics.RecordProxyRouteLookup("cache_miss", hostname) + metrics.RecordProxyRouteLookup("cache_miss") // Query API with timeout ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) @@ -822,7 +822,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client }() bytesCopied, err := io.CopyBuffer(targetConn, clientReader, *bufPtr) - metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied) + metrics.RecordProxyBytesTransmitted("client_to_target", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy client->target error: %v", err) } @@ -842,7 +842,7 @@ func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, client }() bytesCopied, err := 
io.CopyBuffer(clientConn, targetConn, *bufPtr) - metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied) + metrics.RecordProxyBytesTransmitted("target_to_client", bytesCopied) if err != nil && err != io.EOF { logger.Debug("Copy target->client error: %v", err) } From 3f95e2da255c4592a932b3297ed4e256a7bb7aa6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc=20Sch=C3=A4fer?= Date: Mon, 4 May 2026 00:21:10 +0200 Subject: [PATCH 24/25] fix(otel): revert semconv version and correct deployment environment attribute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Marc Schäfer --- internal/observability/otel/resource.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/observability/otel/resource.go b/internal/observability/otel/resource.go index 47a14ff..b0f3b11 100644 --- a/internal/observability/otel/resource.go +++ b/internal/observability/otel/resource.go @@ -3,7 +3,7 @@ package otel import ( "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/sdk/resource" - semconv "go.opentelemetry.io/otel/semconv/v1.40.0" + semconv "go.opentelemetry.io/otel/semconv/v1.18.0" ) // newResource builds an OTel resource for the Gerbil service. 
@@ -15,11 +15,11 @@ func newResource(serviceName, serviceVersion, deploymentEnv string) (*resource.R attrs = append(attrs, semconv.ServiceVersion(serviceVersion)) } if deploymentEnv != "" { - attrs = append(attrs, semconv.DeploymentEnvironmentName(deploymentEnv)) + attrs = append(attrs, semconv.DeploymentEnvironment(deploymentEnv)) } return resource.Merge( resource.Default(), - resource.NewWithAttributes(semconv.SchemaURL, attrs...), + resource.NewSchemaless(attrs...), ) } From 375cb8b0bae6c0a8e2bf50ec7355cce4e1ab8c6a Mon Sep 17 00:00:00 2001 From: Owen Date: Sun, 3 May 2026 15:26:52 -0700 Subject: [PATCH 25/25] Update CODEOWNERS --- .github/CODEOWNERS | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7d8c330..c5f1403 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,2 +1 @@ * @oschwartz10612 @miloschwartz -internal/observability/** @marcschaeferger