mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-20 15:19:53 +00:00
Enhance metrics tracking in SNIProxy connection handling
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fosrl/gerbil/internal/metrics"
|
||||||
"github.com/fosrl/gerbil/logger"
|
"github.com/fosrl/gerbil/logger"
|
||||||
"github.com/patrickmn/go-cache"
|
"github.com/patrickmn/go-cache"
|
||||||
)
|
)
|
||||||
@@ -487,6 +488,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
defer p.wg.Done()
|
defer p.wg.Done()
|
||||||
defer clientConn.Close()
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
metrics.RecordSNIConnection("accepted")
|
||||||
|
|
||||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||||
|
|
||||||
// Check for PROXY protocol from trusted upstream
|
// Check for PROXY protocol from trusted upstream
|
||||||
@@ -497,10 +500,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
var err error
|
var err error
|
||||||
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
metrics.RecordSNIProxyProtocolParseError()
|
||||||
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if proxyInfo != nil {
|
if proxyInfo != nil {
|
||||||
|
metrics.RecordSNITrustedProxyEvent("proxy_protocol_parsed")
|
||||||
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
||||||
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
||||||
} else {
|
} else {
|
||||||
@@ -517,11 +522,13 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Extract SNI hostname
|
// Extract SNI hostname
|
||||||
|
clientHelloStart := time.Now()
|
||||||
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Debug("SNI extraction failed: %v", err)
|
logger.Debug("SNI extraction failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds())
|
||||||
|
|
||||||
if hostname == "" {
|
if hostname == "" {
|
||||||
log.Println("No SNI hostname found")
|
log.Println("No SNI hostname found")
|
||||||
@@ -569,6 +576,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
defer targetConn.Close()
|
defer targetConn.Close()
|
||||||
|
|
||||||
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
||||||
|
metrics.RecordActiveProxyConnection(hostname, 1)
|
||||||
|
defer metrics.RecordActiveProxyConnection(hostname, -1)
|
||||||
|
|
||||||
// Send PROXY protocol header if enabled
|
// Send PROXY protocol header if enabled
|
||||||
if p.proxyProtocol {
|
if p.proxyProtocol {
|
||||||
@@ -618,7 +627,7 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Start bidirectional data transfer
|
// Start bidirectional data transfer
|
||||||
p.pipe(actualClientConn, targetConn, clientReader)
|
p.pipe(hostname, actualClientConn, targetConn, clientReader)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getRoute retrieves routing information for a hostname
|
// getRoute retrieves routing information for a hostname
|
||||||
@@ -626,6 +635,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
|||||||
// Check local overrides first
|
// Check local overrides first
|
||||||
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
||||||
logger.Debug("Local override matched for hostname: %s", hostname)
|
logger.Debug("Local override matched for hostname: %s", hostname)
|
||||||
|
metrics.RecordProxyRouteLookup("local_override", hostname)
|
||||||
return &RouteRecord{
|
return &RouteRecord{
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
TargetHost: p.localProxyAddr,
|
TargetHost: p.localProxyAddr,
|
||||||
@@ -638,6 +648,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
|||||||
_, isLocal := p.localSNIs[hostname]
|
_, isLocal := p.localSNIs[hostname]
|
||||||
p.localSNIsLock.RUnlock()
|
p.localSNIsLock.RUnlock()
|
||||||
if isLocal {
|
if isLocal {
|
||||||
|
metrics.RecordProxyRouteLookup("local", hostname)
|
||||||
return &RouteRecord{
|
return &RouteRecord{
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
TargetHost: p.localProxyAddr,
|
TargetHost: p.localProxyAddr,
|
||||||
@@ -648,13 +659,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
|||||||
// Check cache first
|
// Check cache first
|
||||||
if cached, found := p.cache.Get(hostname); found {
|
if cached, found := p.cache.Get(hostname); found {
|
||||||
if cached == nil {
|
if cached == nil {
|
||||||
|
metrics.RecordProxyRouteLookup("cached_not_found", hostname)
|
||||||
return nil, nil // Cached negative result
|
return nil, nil // Cached negative result
|
||||||
}
|
}
|
||||||
logger.Debug("Cache hit for hostname: %s", hostname)
|
logger.Debug("Cache hit for hostname: %s", hostname)
|
||||||
|
metrics.RecordProxyRouteLookup("cache_hit", hostname)
|
||||||
return cached.(*RouteRecord), nil
|
return cached.(*RouteRecord), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
||||||
|
metrics.RecordProxyRouteLookup("cache_miss", hostname)
|
||||||
|
|
||||||
// Query API with timeout
|
// Query API with timeout
|
||||||
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
||||||
@@ -682,22 +696,28 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
|||||||
req.Header.Set("Content-Type", "application/json")
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
// Make HTTP request
|
// Make HTTP request
|
||||||
|
apiStart := time.Now()
|
||||||
client := &http.Client{Timeout: 5 * time.Second}
|
client := &http.Client{Timeout: 5 * time.Second}
|
||||||
resp, err := client.Do(req)
|
resp, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
metrics.RecordSNIRouteAPIRequest("error")
|
||||||
return nil, fmt.Errorf("API request failed: %w", err)
|
return nil, fmt.Errorf("API request failed: %w", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer resp.Body.Close()
|
||||||
|
metrics.RecordSNIRouteAPILatency(time.Since(apiStart).Seconds())
|
||||||
|
|
||||||
if resp.StatusCode == http.StatusNotFound {
|
if resp.StatusCode == http.StatusNotFound {
|
||||||
|
metrics.RecordSNIRouteAPIRequest("not_found")
|
||||||
// Cache negative result for shorter time (1 minute)
|
// Cache negative result for shorter time (1 minute)
|
||||||
p.cache.Set(hostname, nil, 1*time.Minute)
|
p.cache.Set(hostname, nil, 1*time.Minute)
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
metrics.RecordSNIRouteAPIRequest("error")
|
||||||
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
metrics.RecordSNIRouteAPIRequest("success")
|
||||||
|
|
||||||
// Parse response
|
// Parse response
|
||||||
var apiResponse RouteAPIResponse
|
var apiResponse RouteAPIResponse
|
||||||
@@ -754,7 +774,7 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
|
|||||||
}
|
}
|
||||||
|
|
||||||
// pipe handles bidirectional data transfer between connections
|
// pipe handles bidirectional data transfer between connections
|
||||||
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
|
func (p *SNIProxy) pipe(hostname string, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
wg.Add(2)
|
wg.Add(2)
|
||||||
|
|
||||||
@@ -775,7 +795,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
|||||||
|
|
||||||
// Use a large buffer for better performance
|
// Use a large buffer for better performance
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
bytesCopied, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||||
|
metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy client->target error: %v", err)
|
logger.Debug("Copy client->target error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -788,7 +809,8 @@ func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader)
|
|||||||
|
|
||||||
// Use a large buffer for better performance
|
// Use a large buffer for better performance
|
||||||
buf := make([]byte, 32*1024)
|
buf := make([]byte, 32*1024)
|
||||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
bytesCopied, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||||
|
metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied)
|
||||||
if err != nil && err != io.EOF {
|
if err != nil && err != io.EOF {
|
||||||
logger.Debug("Copy target->client error: %v", err)
|
logger.Debug("Copy target->client error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user