mirror of
https://github.com/fosrl/gerbil.git
synced 2026-05-17 13:49:54 +00:00
Merge remote-tracking branch 'upstream/dev' into proxy-context-tunnel-tracking
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/fosrl/gerbil/internal/metrics"
|
||||
"github.com/fosrl/gerbil/logger"
|
||||
"github.com/patrickmn/go-cache"
|
||||
"golang.org/x/sync/errgroup"
|
||||
@@ -70,6 +71,12 @@ type SNIProxy struct {
|
||||
|
||||
// Trusted upstream proxies that can send PROXY protocol
|
||||
trustedUpstreams map[string]struct{}
|
||||
|
||||
// Reusable HTTP client for API requests
|
||||
httpClient *http.Client
|
||||
|
||||
// Buffer pool for connection piping
|
||||
bufferPool *sync.Pool
|
||||
}
|
||||
|
||||
type activeTunnel struct {
|
||||
@@ -377,6 +384,20 @@ func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, lo
|
||||
localOverrides: overridesMap,
|
||||
activeTunnels: make(map[string]*activeTunnel),
|
||||
trustedUpstreams: trustedMap,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 5 * time.Second,
|
||||
Transport: &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 10,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
},
|
||||
},
|
||||
bufferPool: &sync.Pool{
|
||||
New: func() interface{} {
|
||||
buf := make([]byte, 32*1024)
|
||||
return &buf
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return proxy, nil
|
||||
@@ -490,6 +511,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
defer p.wg.Done()
|
||||
defer clientConn.Close()
|
||||
|
||||
metrics.RecordSNIConnection("accepted")
|
||||
|
||||
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
||||
|
||||
// Check for PROXY protocol from trusted upstream
|
||||
@@ -500,10 +523,12 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
var err error
|
||||
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
||||
if err != nil {
|
||||
metrics.RecordSNIProxyProtocolParseError()
|
||||
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
||||
return
|
||||
}
|
||||
if proxyInfo != nil {
|
||||
metrics.RecordSNITrustedProxyEvent("proxy_protocol_parsed")
|
||||
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
||||
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
||||
} else {
|
||||
@@ -520,11 +545,13 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
}
|
||||
|
||||
// Extract SNI hostname
|
||||
clientHelloStart := time.Now()
|
||||
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
||||
if err != nil {
|
||||
logger.Debug("SNI extraction failed: %v", err)
|
||||
return
|
||||
}
|
||||
metrics.RecordProxyTLSHandshake(hostname, time.Since(clientHelloStart).Seconds())
|
||||
|
||||
if hostname == "" {
|
||||
log.Println("No SNI hostname found")
|
||||
@@ -572,6 +599,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
defer targetConn.Close()
|
||||
|
||||
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
||||
metrics.RecordActiveProxyConnection(hostname, 1)
|
||||
defer metrics.RecordActiveProxyConnection(hostname, -1)
|
||||
|
||||
// Send PROXY protocol header if enabled
|
||||
if p.proxyProtocol {
|
||||
@@ -615,8 +644,8 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
||||
p.activeTunnelsLock.Unlock()
|
||||
}()
|
||||
|
||||
// Start bidirectional data transfer with tunnel context
|
||||
p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader)
|
||||
// Start bidirectional data transfer with tunnel-level cancellation context.
|
||||
p.pipe(tunnelCtx, hostname, actualClientConn, targetConn, clientReader)
|
||||
}
|
||||
|
||||
// getRoute retrieves routing information for a hostname
|
||||
@@ -624,6 +653,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
// Check local overrides first
|
||||
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
||||
logger.Debug("Local override matched for hostname: %s", hostname)
|
||||
metrics.RecordProxyRouteLookup("local_override", hostname)
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
@@ -636,6 +666,7 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
_, isLocal := p.localSNIs[hostname]
|
||||
p.localSNIsLock.RUnlock()
|
||||
if isLocal {
|
||||
metrics.RecordProxyRouteLookup("local", hostname)
|
||||
return &RouteRecord{
|
||||
Hostname: hostname,
|
||||
TargetHost: p.localProxyAddr,
|
||||
@@ -646,13 +677,16 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
// Check cache first
|
||||
if cached, found := p.cache.Get(hostname); found {
|
||||
if cached == nil {
|
||||
metrics.RecordProxyRouteLookup("cached_not_found", hostname)
|
||||
return nil, nil // Cached negative result
|
||||
}
|
||||
logger.Debug("Cache hit for hostname: %s", hostname)
|
||||
metrics.RecordProxyRouteLookup("cache_hit", hostname)
|
||||
return cached.(*RouteRecord), nil
|
||||
}
|
||||
|
||||
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
||||
metrics.RecordProxyRouteLookup("cache_miss", hostname)
|
||||
|
||||
// Query API with timeout
|
||||
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
||||
@@ -680,22 +714,28 @@ func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
// Make HTTP request
|
||||
client := &http.Client{Timeout: 5 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
apiStart := time.Now()
|
||||
// Make HTTP request using reusable client
|
||||
resp, err := p.httpClient.Do(req)
|
||||
if err != nil {
|
||||
metrics.RecordSNIRouteAPIRequest("error")
|
||||
return nil, fmt.Errorf("API request failed: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
metrics.RecordSNIRouteAPILatency(time.Since(apiStart).Seconds())
|
||||
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
metrics.RecordSNIRouteAPIRequest("not_found")
|
||||
// Cache negative result for shorter time (1 minute)
|
||||
p.cache.Set(hostname, nil, 1*time.Minute)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
metrics.RecordSNIRouteAPIRequest("error")
|
||||
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
||||
}
|
||||
metrics.RecordSNIRouteAPIRequest("success")
|
||||
|
||||
// Parse response
|
||||
var apiResponse RouteAPIResponse
|
||||
@@ -752,7 +792,7 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s
|
||||
}
|
||||
|
||||
// pipe handles bidirectional data transfer between connections
|
||||
func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||
func (p *SNIProxy) pipe(ctx context.Context, hostname string, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
||||
g, gCtx := errgroup.WithContext(ctx)
|
||||
|
||||
// Close connections when context cancels to unblock io.Copy operations
|
||||
@@ -761,10 +801,16 @@ func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, cl
|
||||
targetConn.Close()
|
||||
})
|
||||
|
||||
// Copy data from client to target
|
||||
// Copy data from client to target (using buffered reader and pooled memory).
|
||||
g.Go(func() error {
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
||||
bufPtr := p.bufferPool.Get().(*[]byte)
|
||||
defer func() {
|
||||
clear(*bufPtr)
|
||||
p.bufferPool.Put(bufPtr)
|
||||
}()
|
||||
|
||||
bytesCopied, err := io.CopyBuffer(targetConn, clientReader, *bufPtr)
|
||||
metrics.RecordProxyBytesTransmitted(hostname, "client_to_target", bytesCopied)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy client->target error: %v", err)
|
||||
}
|
||||
@@ -773,15 +819,21 @@ func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, cl
|
||||
|
||||
// Copy data from target to client
|
||||
g.Go(func() error {
|
||||
buf := make([]byte, 32*1024)
|
||||
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
||||
bufPtr := p.bufferPool.Get().(*[]byte)
|
||||
defer func() {
|
||||
clear(*bufPtr)
|
||||
p.bufferPool.Put(bufPtr)
|
||||
}()
|
||||
|
||||
bytesCopied, err := io.CopyBuffer(clientConn, targetConn, *bufPtr)
|
||||
metrics.RecordProxyBytesTransmitted(hostname, "target_to_client", bytesCopied)
|
||||
if err != nil && err != io.EOF {
|
||||
logger.Debug("Copy target->client error: %v", err)
|
||||
}
|
||||
return err
|
||||
})
|
||||
|
||||
g.Wait()
|
||||
_ = g.Wait()
|
||||
}
|
||||
|
||||
// GetCacheStats returns cache statistics
|
||||
|
||||
Reference in New Issue
Block a user