mirror of
https://github.com/fosrl/gerbil.git
synced 2026-03-27 04:56:42 +00:00
- Replace []net.Conn slice with context + atomic counter in activeTunnel - Use errgroup.WithContext for pipe() to handle goroutine lifecycle - Use context.AfterFunc to close connections on cancellation - Fix race condition by comparing tunnel pointers instead of map lookup - UpdateLocalSNIs now cancels tunnel context instead of iterating conns This eliminates O(n) connection removal, prevents goroutine leaks, and provides cleaner cancellation semantics.
835 lines
24 KiB
Go
835 lines
24 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"encoding/json"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/fosrl/gerbil/logger"
|
|
"github.com/patrickmn/go-cache"
|
|
"golang.org/x/sync/errgroup"
|
|
)
|
|
|
|
// RouteRecord describes where connections for one SNI hostname are forwarded.
type RouteRecord struct {
	// Hostname is the SNI hostname the route was resolved for, and
	// TargetHost is the host that connections are dialed to.
	Hostname, TargetHost string
	// TargetPort is the TCP port dialed on TargetHost.
	TargetPort int
}
|
|
|
|
// RouteAPIResponse is the JSON body decoded from the route lookup API.
type RouteAPIResponse struct {
	// Endpoints is read from the "endpoints" JSON array.
	Endpoints []string `json:"endpoints"`
}
|
|
|
|
// ProxyProtocolInfo carries the fields parsed from an incoming PROXY protocol
// v1 header.
type ProxyProtocolInfo struct {
	// Protocol is the address family token from the header (TCP4 or TCP6).
	Protocol string
	// SrcIP and DestIP are the source and destination addresses exactly as
	// they appeared in the header.
	SrcIP, DestIP string
	// SrcPort and DestPort are the parsed source and destination ports.
	SrcPort, DestPort int
	// OriginalConn is the connection to keep reading from once the header
	// has been consumed.
	OriginalConn net.Conn
}
|
|
|
|
// SNIProxy represents the main proxy server
|
|
type SNIProxy struct {
|
|
port int
|
|
cache *cache.Cache
|
|
listener net.Listener
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
wg sync.WaitGroup
|
|
localProxyAddr string
|
|
localProxyPort int
|
|
remoteConfigURL string
|
|
publicKey string
|
|
proxyProtocol bool // Enable PROXY protocol v1
|
|
|
|
// New fields for fast local SNI lookup
|
|
localSNIs map[string]struct{}
|
|
localSNIsLock sync.RWMutex
|
|
|
|
// Local overrides for domains that should always use local proxy
|
|
localOverrides map[string]struct{}
|
|
|
|
// Track active tunnels by SNI
|
|
activeTunnels map[string]*activeTunnel
|
|
activeTunnelsLock sync.Mutex
|
|
|
|
// Trusted upstream proxies that can send PROXY protocol
|
|
trustedUpstreams map[string]struct{}
|
|
}
|
|
|
|
// activeTunnel is the shared state for all connections currently proxied for
// one SNI hostname.
type activeTunnel struct {
	// ctx is handed to pipe for every connection using this tunnel.
	ctx context.Context
	// cancel cancels ctx.
	cancel context.CancelFunc
	// count is the number of connections currently using the tunnel.
	count atomic.Int64
}
|
|
|
|
// readOnlyConn adapts an io.Reader to the net.Conn interface. Only Read does
// real work: Write always fails and every other method is a no-op.
type readOnlyConn struct {
	reader io.Reader
}

// Read reads from the wrapped reader.
func (c readOnlyConn) Read(buf []byte) (int, error) {
	return c.reader.Read(buf)
}

// Write rejects all writes with io.ErrClosedPipe.
func (readOnlyConn) Write([]byte) (int, error) {
	return 0, io.ErrClosedPipe
}

// Close does nothing and reports success.
func (readOnlyConn) Close() error {
	return nil
}

// LocalAddr returns nil; there is no underlying socket.
func (readOnlyConn) LocalAddr() net.Addr {
	return nil
}

// RemoteAddr returns nil; there is no underlying socket.
func (readOnlyConn) RemoteAddr() net.Addr {
	return nil
}

// SetDeadline is a no-op.
func (readOnlyConn) SetDeadline(time.Time) error {
	return nil
}

// SetReadDeadline is a no-op.
func (readOnlyConn) SetReadDeadline(time.Time) error {
	return nil
}

// SetWriteDeadline is a no-op.
func (readOnlyConn) SetWriteDeadline(time.Time) error {
	return nil
}
|
|
|
|
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
|
|
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
|
|
// Check if the connection comes from a trusted upstream
|
|
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
|
|
if err != nil {
|
|
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
|
|
}
|
|
|
|
// Resolve the remote IP to hostname to check if it's trusted
|
|
// For simplicity, we'll check the IP directly in trusted upstreams
|
|
// In production, you might want to do reverse DNS lookup
|
|
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
|
|
// Not from trusted upstream, return original connection
|
|
return nil, conn, nil
|
|
}
|
|
|
|
// Set read timeout for PROXY protocol parsing
|
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
|
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
|
|
}
|
|
|
|
// Read the first line (PROXY protocol header)
|
|
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
|
|
n, err := conn.Read(buffer)
|
|
if err != nil {
|
|
// If we can't read from trusted upstream, treat as regular connection
|
|
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
|
|
// Clear read timeout before returning
|
|
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
|
|
logger.Debug("Failed to clear read deadline: %v", clearErr)
|
|
}
|
|
return nil, conn, nil
|
|
}
|
|
|
|
// Find the end of the first line (CRLF)
|
|
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
|
|
if headerEnd == -1 {
|
|
// No PROXY protocol header found, treat as regular TLS connection
|
|
// Return the connection with the buffered data prepended
|
|
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
|
|
|
|
// Clear read timeout
|
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
|
logger.Debug("Failed to clear read deadline: %v", err)
|
|
}
|
|
|
|
// Create a reader that includes the buffered data + original connection
|
|
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
|
wrappedConn := &proxyProtocolConn{
|
|
Conn: conn,
|
|
reader: newReader,
|
|
}
|
|
return nil, wrappedConn, nil
|
|
}
|
|
|
|
headerLine := string(buffer[:headerEnd])
|
|
remainingData := buffer[headerEnd+2 : n]
|
|
|
|
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
|
|
parts := strings.Fields(headerLine)
|
|
if len(parts) != 6 || parts[0] != "PROXY" {
|
|
// Check for PROXY UNKNOWN
|
|
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
|
|
// PROXY UNKNOWN - use original connection info
|
|
return nil, conn, nil
|
|
}
|
|
// Invalid PROXY protocol, but might be regular TLS - treat as such
|
|
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
|
|
|
|
// Clear read timeout
|
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
|
logger.Debug("Failed to clear read deadline: %v", err)
|
|
}
|
|
|
|
// Return the connection with all buffered data prepended
|
|
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
|
|
wrappedConn := &proxyProtocolConn{
|
|
Conn: conn,
|
|
reader: newReader,
|
|
}
|
|
return nil, wrappedConn, nil
|
|
}
|
|
|
|
protocol := parts[1]
|
|
srcIP := parts[2]
|
|
destIP := parts[3]
|
|
srcPort, err := strconv.Atoi(parts[4])
|
|
if err != nil {
|
|
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
|
|
}
|
|
destPort, err := strconv.Atoi(parts[5])
|
|
if err != nil {
|
|
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
|
|
}
|
|
|
|
// Create a new reader that includes remaining data + original connection
|
|
var newReader io.Reader
|
|
if len(remainingData) > 0 {
|
|
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
|
|
} else {
|
|
newReader = conn
|
|
}
|
|
|
|
// Create a wrapper connection that reads from the combined reader
|
|
wrappedConn := &proxyProtocolConn{
|
|
Conn: conn,
|
|
reader: newReader,
|
|
}
|
|
|
|
proxyInfo := &ProxyProtocolInfo{
|
|
Protocol: protocol,
|
|
SrcIP: srcIP,
|
|
DestIP: destIP,
|
|
SrcPort: srcPort,
|
|
DestPort: destPort,
|
|
OriginalConn: wrappedConn,
|
|
}
|
|
|
|
// Clear read timeout
|
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
|
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
|
|
}
|
|
|
|
return proxyInfo, wrappedConn, nil
|
|
}
|
|
|
|
// proxyProtocolConn is a net.Conn whose reads come from reader instead of the
// embedded connection; every other method is served by the embedded Conn.
type proxyProtocolConn struct {
	net.Conn
	reader io.Reader
}

// Read fills dst from the substitute reader.
func (c *proxyProtocolConn) Read(dst []byte) (int, error) {
	src := c.reader
	return src.Read(dst)
}
|
|
|
|
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
|
|
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
|
|
targetTCP, ok := targetAddr.(*net.TCPAddr)
|
|
if !ok {
|
|
// Fallback for unknown address types
|
|
return "PROXY UNKNOWN\r\n"
|
|
}
|
|
|
|
// Use the original client information from the PROXY protocol
|
|
var targetIP string
|
|
var protocol string
|
|
|
|
// Parse source IP to determine protocol family
|
|
srcIP := net.ParseIP(proxyInfo.SrcIP)
|
|
if srcIP == nil {
|
|
return "PROXY UNKNOWN\r\n"
|
|
}
|
|
|
|
if srcIP.To4() != nil {
|
|
// Source is IPv4, use TCP4 protocol
|
|
protocol = "TCP4"
|
|
if targetTCP.IP.To4() != nil {
|
|
// Target is also IPv4, use as-is
|
|
targetIP = targetTCP.IP.String()
|
|
} else {
|
|
// Target is IPv6, but we need IPv4 for consistent protocol family
|
|
if targetTCP.IP.IsLoopback() {
|
|
targetIP = "127.0.0.1"
|
|
} else {
|
|
targetIP = "127.0.0.1" // Safe fallback
|
|
}
|
|
}
|
|
} else {
|
|
// Source is IPv6, use TCP6 protocol
|
|
protocol = "TCP6"
|
|
if targetTCP.IP.To4() != nil {
|
|
// Target is IPv4, convert to IPv6 representation
|
|
targetIP = "::ffff:" + targetTCP.IP.String()
|
|
} else {
|
|
// Target is also IPv6, use as-is
|
|
targetIP = targetTCP.IP.String()
|
|
}
|
|
}
|
|
|
|
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
|
|
protocol,
|
|
proxyInfo.SrcIP,
|
|
targetIP,
|
|
proxyInfo.SrcPort,
|
|
targetTCP.Port)
|
|
}
|
|
|
|
// buildProxyProtocolHeader renders a PROXY protocol v1 header line for a
// connection from clientAddr, using targetAddr as the destination.
//
// The address family (TCP4 or TCP6) follows the client IP. The destination IP
// is coerced into the same family: an IPv6 destination with an IPv4 client is
// replaced by 127.0.0.1, and an IPv4 destination with an IPv6 client is
// written in "::ffff:a.b.c.d" form.
//
// "PROXY UNKNOWN\r\n" is returned when either address is not a *net.TCPAddr,
// is a nil pointer, or carries no IP.
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
	const unknown = "PROXY UNKNOWN\r\n"

	clientTCP, ok := clientAddr.(*net.TCPAddr)
	if !ok || clientTCP == nil || len(clientTCP.IP) == 0 {
		return unknown
	}
	targetTCP, ok := targetAddr.(*net.TCPAddr)
	if !ok || targetTCP == nil || len(targetTCP.IP) == 0 {
		return unknown
	}

	protocol := "TCP6"
	targetIP := targetTCP.IP.String()
	targetV4 := targetTCP.IP.To4()

	if clientTCP.IP.To4() != nil {
		protocol = "TCP4"
		if targetV4 == nil {
			// No IPv4 form exists for an IPv6 target; use loopback so both
			// addresses stay in the same family.
			targetIP = "127.0.0.1"
		}
	} else if targetV4 != nil {
		targetIP = "::ffff:" + targetV4.String()
	}

	return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
		protocol,
		clientTCP.IP.String(),
		targetIP,
		clientTCP.Port,
		targetTCP.Port)
}
|
|
|
|
// NewSNIProxy creates a new SNI proxy instance
|
|
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
// Create local overrides map
|
|
overridesMap := make(map[string]struct{})
|
|
for _, domain := range localOverrides {
|
|
if domain != "" {
|
|
overridesMap[domain] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Create trusted upstreams map
|
|
trustedMap := make(map[string]struct{})
|
|
for _, upstream := range trustedUpstreams {
|
|
if upstream != "" {
|
|
// Add both the domain and potentially resolved IPs
|
|
trustedMap[upstream] = struct{}{}
|
|
|
|
// Try to resolve the domain to IPs and add them too
|
|
if ips, err := net.LookupIP(upstream); err == nil {
|
|
for _, ip := range ips {
|
|
trustedMap[ip.String()] = struct{}{}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
proxy := &SNIProxy{
|
|
port: port,
|
|
cache: cache.New(3*time.Second, 10*time.Minute),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
localProxyAddr: localProxyAddr,
|
|
localProxyPort: localProxyPort,
|
|
remoteConfigURL: remoteConfigURL,
|
|
publicKey: publicKey,
|
|
proxyProtocol: proxyProtocol,
|
|
localSNIs: make(map[string]struct{}),
|
|
localOverrides: overridesMap,
|
|
activeTunnels: make(map[string]*activeTunnel),
|
|
trustedUpstreams: trustedMap,
|
|
}
|
|
|
|
return proxy, nil
|
|
}
|
|
|
|
// Start begins listening for connections
|
|
func (p *SNIProxy) Start() error {
|
|
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
|
|
}
|
|
|
|
p.listener = listener
|
|
logger.Debug("SNI Proxy listening on port %d", p.port)
|
|
|
|
// Accept connections in a goroutine
|
|
go p.acceptConnections()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down the proxy
|
|
func (p *SNIProxy) Stop() error {
|
|
log.Println("Stopping SNI Proxy...")
|
|
|
|
p.cancel()
|
|
|
|
if p.listener != nil {
|
|
p.listener.Close()
|
|
}
|
|
|
|
// Wait for all goroutines to finish with timeout
|
|
done := make(chan struct{})
|
|
go func() {
|
|
p.wg.Wait()
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-done:
|
|
log.Println("All connections closed gracefully")
|
|
case <-time.After(30 * time.Second):
|
|
log.Println("Timeout waiting for connections to close")
|
|
}
|
|
|
|
log.Println("SNI Proxy stopped")
|
|
return nil
|
|
}
|
|
|
|
// acceptConnections handles incoming connections
|
|
func (p *SNIProxy) acceptConnections() {
|
|
for {
|
|
conn, err := p.listener.Accept()
|
|
if err != nil {
|
|
select {
|
|
case <-p.ctx.Done():
|
|
return
|
|
default:
|
|
logger.Debug("Accept error: %v", err)
|
|
continue
|
|
}
|
|
}
|
|
|
|
p.wg.Add(1)
|
|
go p.handleConnection(conn)
|
|
}
|
|
}
|
|
|
|
// readClientHello reads and parses the TLS ClientHello message
|
|
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
|
|
var hello *tls.ClientHelloInfo
|
|
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
|
|
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
|
|
hello = new(tls.ClientHelloInfo)
|
|
*hello = *argHello
|
|
return nil, nil
|
|
},
|
|
}).Handshake()
|
|
if hello == nil {
|
|
return nil, err
|
|
}
|
|
return hello, nil
|
|
}
|
|
|
|
// peekClientHello reads the ClientHello while preserving the data for forwarding
|
|
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
|
|
peekedBytes := new(bytes.Buffer)
|
|
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return hello, io.MultiReader(peekedBytes, reader), nil
|
|
}
|
|
|
|
// extractSNI extracts the SNI hostname from the TLS ClientHello
|
|
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
|
|
clientHello, clientReader, err := p.peekClientHello(conn)
|
|
if err != nil {
|
|
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
|
|
}
|
|
|
|
if clientHello.ServerName == "" {
|
|
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
|
|
}
|
|
|
|
return clientHello.ServerName, clientReader, nil
|
|
}
|
|
|
|
// handleConnection processes a single client connection
|
|
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
|
|
defer p.wg.Done()
|
|
defer clientConn.Close()
|
|
|
|
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
|
|
|
|
// Check for PROXY protocol from trusted upstream
|
|
var proxyInfo *ProxyProtocolInfo
|
|
var actualClientConn net.Conn = clientConn
|
|
|
|
if len(p.trustedUpstreams) > 0 {
|
|
var err error
|
|
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
|
|
if err != nil {
|
|
logger.Debug("Failed to parse PROXY protocol: %v", err)
|
|
return
|
|
}
|
|
if proxyInfo != nil {
|
|
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
|
|
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
|
|
} else {
|
|
// No PROXY protocol detected, but connection is from trusted upstream
|
|
// This is fine - treat as regular connection
|
|
logger.Debug("No PROXY protocol detected from trusted upstream, treating as regular connection")
|
|
}
|
|
}
|
|
|
|
// Set read timeout for SNI extraction
|
|
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
|
logger.Debug("Failed to set read deadline: %v", err)
|
|
return
|
|
}
|
|
|
|
// Extract SNI hostname
|
|
hostname, clientReader, err := p.extractSNI(actualClientConn)
|
|
if err != nil {
|
|
logger.Debug("SNI extraction failed: %v", err)
|
|
return
|
|
}
|
|
|
|
if hostname == "" {
|
|
log.Println("No SNI hostname found")
|
|
return
|
|
}
|
|
|
|
logger.Debug("SNI hostname detected: %s", hostname)
|
|
|
|
// Remove read timeout for normal operation
|
|
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
|
|
logger.Debug("Failed to clear read deadline: %v", err)
|
|
return
|
|
}
|
|
|
|
// Get routing information - use original client address if available from PROXY protocol
|
|
var clientAddrStr string
|
|
if proxyInfo != nil {
|
|
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
|
|
} else {
|
|
clientAddrStr = clientConn.RemoteAddr().String()
|
|
}
|
|
|
|
route, err := p.getRoute(hostname, clientAddrStr)
|
|
if err != nil {
|
|
logger.Debug("Failed to get route for %s: %v", hostname, err)
|
|
return
|
|
}
|
|
|
|
if route == nil {
|
|
logger.Debug("No route found for hostname: %s", hostname)
|
|
return
|
|
}
|
|
|
|
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
|
|
|
|
// Connect to target server
|
|
targetConn, err := net.DialTimeout("tcp",
|
|
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
|
|
10*time.Second)
|
|
if err != nil {
|
|
logger.Debug("Failed to connect to target %s:%d: %v",
|
|
route.TargetHost, route.TargetPort, err)
|
|
return
|
|
}
|
|
defer targetConn.Close()
|
|
|
|
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
|
|
|
|
// Send PROXY protocol header if enabled
|
|
if p.proxyProtocol {
|
|
var proxyHeader string
|
|
if proxyInfo != nil {
|
|
// Use original client info from PROXY protocol
|
|
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
|
|
} else {
|
|
// Use direct client connection info
|
|
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
|
|
}
|
|
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
|
|
|
|
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
|
|
logger.Debug("Failed to send PROXY protocol header: %v", err)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Track this tunnel by SNI using context for cancellation
|
|
p.activeTunnelsLock.Lock()
|
|
tunnel, ok := p.activeTunnels[hostname]
|
|
if !ok {
|
|
ctx, cancel := context.WithCancel(p.ctx)
|
|
tunnel = &activeTunnel{ctx: ctx, cancel: cancel}
|
|
p.activeTunnels[hostname] = tunnel
|
|
}
|
|
tunnel.count.Add(1)
|
|
tunnelCtx := tunnel.ctx
|
|
p.activeTunnelsLock.Unlock()
|
|
|
|
defer func() {
|
|
// Decrement count atomically; if we're the last connection, clean up
|
|
if tunnel.count.Add(-1) == 0 {
|
|
tunnel.cancel()
|
|
p.activeTunnelsLock.Lock()
|
|
// Only delete if the map still points to our tunnel
|
|
if p.activeTunnels[hostname] == tunnel {
|
|
delete(p.activeTunnels, hostname)
|
|
}
|
|
p.activeTunnelsLock.Unlock()
|
|
}
|
|
}()
|
|
|
|
// Start bidirectional data transfer with tunnel context
|
|
p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader)
|
|
}
|
|
|
|
// getRoute retrieves routing information for a hostname
|
|
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
|
|
// Check local overrides first
|
|
if _, isOverride := p.localOverrides[hostname]; isOverride {
|
|
logger.Debug("Local override matched for hostname: %s", hostname)
|
|
return &RouteRecord{
|
|
Hostname: hostname,
|
|
TargetHost: p.localProxyAddr,
|
|
TargetPort: p.localProxyPort,
|
|
}, nil
|
|
}
|
|
|
|
// Fast path: check if hostname is in localSNIs
|
|
p.localSNIsLock.RLock()
|
|
_, isLocal := p.localSNIs[hostname]
|
|
p.localSNIsLock.RUnlock()
|
|
if isLocal {
|
|
return &RouteRecord{
|
|
Hostname: hostname,
|
|
TargetHost: p.localProxyAddr,
|
|
TargetPort: p.localProxyPort,
|
|
}, nil
|
|
}
|
|
|
|
// Check cache first
|
|
if cached, found := p.cache.Get(hostname); found {
|
|
if cached == nil {
|
|
return nil, nil // Cached negative result
|
|
}
|
|
logger.Debug("Cache hit for hostname: %s", hostname)
|
|
return cached.(*RouteRecord), nil
|
|
}
|
|
|
|
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
|
|
|
|
// Query API with timeout
|
|
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
// Construct API URL (without hostname in path)
|
|
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
|
|
|
|
// Create request body with hostname and public key
|
|
requestBody := map[string]string{
|
|
"hostname": hostname,
|
|
"publicKey": p.publicKey,
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(requestBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request body: %w", err)
|
|
}
|
|
|
|
// Create HTTP request
|
|
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// Make HTTP request
|
|
client := &http.Client{Timeout: 5 * time.Second}
|
|
resp, err := client.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("API request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode == http.StatusNotFound {
|
|
// Cache negative result for shorter time (1 minute)
|
|
p.cache.Set(hostname, nil, 1*time.Minute)
|
|
return nil, nil
|
|
}
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
|
|
}
|
|
|
|
// Parse response
|
|
var apiResponse RouteAPIResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
|
|
return nil, fmt.Errorf("failed to decode API response: %w", err)
|
|
}
|
|
|
|
endpoints := apiResponse.Endpoints
|
|
|
|
// Default target configuration
|
|
targetHost := p.localProxyAddr
|
|
targetPort := p.localProxyPort
|
|
|
|
// If no endpoints returned, use local node
|
|
if len(endpoints) == 0 {
|
|
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
|
|
} else {
|
|
// Select endpoint using consistent hashing for stickiness
|
|
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
|
|
targetHost = selectedEndpoint
|
|
targetPort = 443 // Default HTTPS port
|
|
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
|
|
}
|
|
|
|
route := &RouteRecord{
|
|
Hostname: hostname,
|
|
TargetHost: targetHost,
|
|
TargetPort: targetPort,
|
|
}
|
|
|
|
// Cache the result
|
|
p.cache.Set(hostname, route, cache.DefaultExpiration)
|
|
logger.Debug("Cached route for hostname: %s", hostname)
|
|
|
|
return route, nil
|
|
}
|
|
|
|
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
|
|
// the same client always routes to the same endpoint for load balancing
|
|
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
|
|
if len(endpoints) == 0 {
|
|
return p.localProxyAddr
|
|
}
|
|
if len(endpoints) == 1 {
|
|
return endpoints[0]
|
|
}
|
|
|
|
// Use FNV hash for consistent selection based on client address
|
|
hash := fnv.New32a()
|
|
hash.Write([]byte(clientAddr))
|
|
index := hash.Sum32() % uint32(len(endpoints))
|
|
|
|
return endpoints[index]
|
|
}
|
|
|
|
// pipe handles bidirectional data transfer between connections
|
|
func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) {
|
|
g, ctx := errgroup.WithContext(ctx)
|
|
|
|
// Close connections when context cancels to unblock io.Copy operations
|
|
context.AfterFunc(ctx, func() {
|
|
clientConn.Close()
|
|
targetConn.Close()
|
|
})
|
|
|
|
// Copy data from client to target
|
|
g.Go(func() error {
|
|
buf := make([]byte, 32*1024)
|
|
_, err := io.CopyBuffer(targetConn, clientReader, buf)
|
|
if err != nil && err != io.EOF {
|
|
logger.Debug("Copy client->target error: %v", err)
|
|
}
|
|
return err
|
|
})
|
|
|
|
// Copy data from target to client
|
|
g.Go(func() error {
|
|
buf := make([]byte, 32*1024)
|
|
_, err := io.CopyBuffer(clientConn, targetConn, buf)
|
|
if err != nil && err != io.EOF {
|
|
logger.Debug("Copy target->client error: %v", err)
|
|
}
|
|
return err
|
|
})
|
|
|
|
g.Wait()
|
|
}
|
|
|
|
// GetCacheStats returns cache statistics
|
|
func (p *SNIProxy) GetCacheStats() (int, int) {
|
|
return p.cache.ItemCount(), len(p.cache.Items())
|
|
}
|
|
|
|
// ClearCache clears all cached entries
|
|
func (p *SNIProxy) ClearCache() {
|
|
p.cache.Flush()
|
|
log.Println("Cache cleared")
|
|
}
|
|
|
|
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
|
|
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
|
|
newSNIs := make(map[string]struct{})
|
|
for _, domain := range fullDomains {
|
|
newSNIs[domain] = struct{}{}
|
|
// Invalidate any cached route for this domain
|
|
p.cache.Delete(domain)
|
|
}
|
|
|
|
// Update localSNIs
|
|
p.localSNIsLock.Lock()
|
|
removed := make([]string, 0)
|
|
for sni := range p.localSNIs {
|
|
if _, stillLocal := newSNIs[sni]; !stillLocal {
|
|
removed = append(removed, sni)
|
|
}
|
|
}
|
|
p.localSNIs = newSNIs
|
|
p.localSNIsLock.Unlock()
|
|
|
|
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
|
|
|
|
// Terminate tunnels for removed SNIs via context cancellation
|
|
if len(removed) > 0 {
|
|
p.activeTunnelsLock.Lock()
|
|
for _, sni := range removed {
|
|
if tunnel, ok := p.activeTunnels[sni]; ok {
|
|
tunnel.cancel()
|
|
delete(p.activeTunnels, sni)
|
|
logger.Debug("Cancelled tunnel context for SNI target change: %s", sni)
|
|
}
|
|
}
|
|
p.activeTunnelsLock.Unlock()
|
|
}
|
|
}
|