Enhance metrics tracking in SNIProxy connection handling

This commit is contained in:
Marc Schäfer
2026-04-03 18:15:41 +02:00
parent 4357ddf64b
commit e47a57cb4f

View File

@@ -16,6 +16,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/fosrl/gerbil/internal/metrics"
"github.com/fosrl/gerbil/logger" "github.com/fosrl/gerbil/logger"
"github.com/patrickmn/go-cache" "github.com/patrickmn/go-cache"
) )
@@ -487,6 +488,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer p.wg.Done() defer p.wg.Done()
defer clientConn.Close() defer clientConn.Close()
metrics.RecordSNIConnection("accepted")
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Check for PROXY protocol from trusted upstream // Check for PROXY protocol from trusted upstream
@@ -497,10 +500,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
var err error var err error
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn) proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
if err != nil { if err != nil {
metrics.RecordSNIProxyProtocolParseError()
logger.Debug("Failed to parse PROXY protocol: %v", err) logger.Debug("Failed to parse PROXY protocol: %v", err)
return return
} }
if proxyInfo != nil { if proxyInfo != nil {
metrics.RecordSNITrustedProxyEvent("proxy_protocol_parsed")
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d", logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort) proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
} else { } else {
@@ -517,11 +522,13 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
} }
// Extract SNI hostname // Extract SNI hostname
clientHelloStart := time.Now()
hostname, clientReader, err := p.extractSNI(actualClientConn) hostname, clientReader, err := p.extractSNI(actualClientConn)
if err != nil { if err != nil {
logger.Debug("SNI extraction failed: %v", err) logger.Debug("SNI extraction failed: %v", err)
return return
} }
metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds())
if hostname == "" { if hostname == "" {
log.Println("No SNI hostname found") log.Println("No SNI hostname found")
@@ -569,6 +576,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer targetConn.Close() defer targetConn.Close()
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
metrics.RecordActiveProxyConnection(hostname, 1)
defer metrics.RecordActiveProxyConnection(hostname, -1)
// Send PROXY protocol header if enabled // Send PROXY protocol header if enabled
if p.proxyProtocol { if p.proxyProtocol {
@@ -618,7 +627,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
}() }()
// Start bidirectional data transfer // Start bidirectional data transfer
p.pipe(actualClientConn, targetConn, clientReader) p.pipe(hostname, actualClientConn, targetConn, clientReader)
} }
// getRoute retrieves routing information for a hostname // getRoute retrieves routing information for a hostname
@@ -626,6 +635,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first // Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride { if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname) logger.Debug("Local override matched for hostname: %s", hostname)
metrics.RecordProxyRouteLookup("local_override", hostname)
return &RouteRecord{ return &RouteRecord{
Hostname: hostname, Hostname: hostname,
TargetHost: p.localProxyAddr, TargetHost: p.localProxyAddr,
@@ -638,6 +648,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
_, isLocal := p.localSNIs[hostname] _, isLocal := p.localSNIs[hostname]
p.localSNIsLock.RUnlock() p.localSNIsLock.RUnlock()
if isLocal { if isLocal {
metrics.RecordProxyRouteLookup("local", hostname)
return &RouteRecord{ return &RouteRecord{
Hostname: hostname, Hostname: hostname,
TargetHost: p.localProxyAddr, TargetHost: p.localProxyAddr,
@@ -648,13 +659,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check cache first // Check cache first
if cached, found := p.cache.Get(hostname); found { if cached, found := p.cache.Get(hostname); found {
if cached == nil { if cached == nil {
metrics.RecordProxyRouteLookup("cached_not_found", hostname)
return nil, nil // Cached negative result return nil, nil // Cached negative result
} }
logger.Debug("Cache hit for hostname: %s", hostname) logger.Debug("Cache hit for hostname: %s", hostname)
metrics.RecordProxyRouteLookup("cache_hit", hostname)
return cached.(*RouteRecord), nil return cached.(*RouteRecord), nil
} }
logger.Debug("Cache miss for hostname: %s, querying API", hostname) logger.Debug("Cache miss for hostname: %s, querying API", hostname)
metrics.RecordProxyRouteLookup("cache_miss", hostname)
// Query API with timeout // Query API with timeout
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
@@ -682,22 +696,28 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
// Make HTTP request // Make HTTP request
apiStart := time.Now()
client := &http.Client{Timeout: 5 * time.Second} client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
metrics.RecordSNIRouteAPIRequest("error")
return nil, fmt.Errorf("API request failed: %w", err) return nil, fmt.Errorf("API request failed: %w", err)
} }
defer resp.Body.Close() defer resp.Body.Close()
metrics.RecordSNIRouteAPILatency(time.Since(apiStart).Seconds())
if resp.StatusCode == http.StatusNotFound { if resp.StatusCode == http.StatusNotFound {
metrics.RecordSNIRouteAPIRequest("not_found")
// Cache negative result for shorter time (1 minute) // Cache negative result for shorter time (1 minute)
p.cache.Set(hostname, nil, 1*time.Minute) p.cache.Set(hostname, nil, 1*time.Minute)
return nil, nil return nil, nil
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
metrics.RecordSNIRouteAPIRequest("error")
return nil, fmt.Errorf("API returned status %d", resp.StatusCode) return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
} }
metrics.RecordSNIRouteAPIRequest("success")
// Parse response // Parse response
var apiResponse RouteAPIResponse var apiResponse RouteAPIResponse
@@ -754,7 +774,7 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
} }
// pipe handles bidirectional data transfer between connections // pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(2)
@@ -775,7 +795,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
// Use a large buffer for better performance // Use a large buffer for better performance
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(targetConn, clientReader, buf) bytesCopied, err := io.CopyBuffer(targetConn, clientReader, buf)
metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err) logger.Debug("Copy client->target error: %v", err)
} }
@@ -788,7 +809,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
// Use a large buffer for better performance // Use a large buffer for better performance
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(clientConn, targetConn, buf) bytesCopied, err := io.CopyBuffer(clientConn, targetConn, buf)
metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied)
if err != nil && err != io.EOF { if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err) logger.Debug("Copy target->client error: %v", err)
} }