Files
gerbil/proxy/proxy.go

846 lines
24 KiB
Go

package proxy
import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
"fmt"
"hash/fnv"
"io"
"log"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/fosrl/gerbil/logger"
"github.com/patrickmn/go-cache"
)
// RouteRecord represents a routing configuration
type RouteRecord struct {
Hostname string
TargetHost string
TargetPort int
}
// RouteAPIResponse represents the response from the route API
type RouteAPIResponse struct {
Endpoints []string `json:"endpoints"`
}
// ProxyProtocolInfo holds information parsed from incoming PROXY protocol header
type ProxyProtocolInfo struct {
Protocol string // TCP4 or TCP6
SrcIP string
DestIP string
SrcPort int
DestPort int
OriginalConn net.Conn // The original connection after PROXY protocol parsing
}
// SNIProxy represents the main proxy server
type SNIProxy struct {
port int
cache *cache.Cache
listener net.Listener
ctx context.Context
cancel context.CancelFunc
wg sync.WaitGroup
localProxyAddr string
localProxyPort int
remoteConfigURL string
publicKey string
proxyProtocol bool // Enable PROXY protocol v1
// New fields for fast local SNI lookup
localSNIs map[string]struct{}
localSNIsLock sync.RWMutex
// Local overrides for domains that should always use local proxy
localOverrides map[string]struct{}
// Track active tunnels by SNI
activeTunnels map[string]*activeTunnel
activeTunnelsLock sync.Mutex
// Trusted upstream proxies that can send PROXY protocol
trustedUpstreams map[string]struct{}
}
type activeTunnel struct {
conns []net.Conn
}
// readOnlyConn is a wrapper for io.Reader that implements net.Conn
type readOnlyConn struct {
reader io.Reader
}
func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) }
func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe }
func (conn readOnlyConn) Close() error { return nil }
func (conn readOnlyConn) LocalAddr() net.Addr { return nil }
func (conn readOnlyConn) RemoteAddr() net.Addr { return nil }
func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil }
func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil }
// parseProxyProtocolHeader parses a PROXY protocol v1 header from the connection
func (p *SNIProxy) parseProxyProtocolHeader(conn net.Conn) (*ProxyProtocolInfo, net.Conn, error) {
// Check if the connection comes from a trusted upstream
remoteHost, _, err := net.SplitHostPort(conn.RemoteAddr().String())
if err != nil {
return nil, conn, fmt.Errorf("failed to parse remote address: %w", err)
}
// Resolve the remote IP to hostname to check if it's trusted
// For simplicity, we'll check the IP directly in trusted upstreams
// In production, you might want to do reverse DNS lookup
if _, isTrusted := p.trustedUpstreams[remoteHost]; !isTrusted {
// Not from trusted upstream, return original connection
return nil, conn, nil
}
// Set read timeout for PROXY protocol parsing
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, conn, fmt.Errorf("failed to set read deadline: %w", err)
}
// Read the first line (PROXY protocol header)
buffer := make([]byte, 512) // PROXY protocol header should be much smaller
n, err := conn.Read(buffer)
if err != nil {
// If we can't read from trusted upstream, treat as regular connection
logger.Debug("Could not read from trusted upstream %s, treating as regular connection: %v", remoteHost, err)
// Clear read timeout before returning
if clearErr := conn.SetReadDeadline(time.Time{}); clearErr != nil {
logger.Debug("Failed to clear read deadline: %v", clearErr)
}
return nil, conn, nil
}
// Find the end of the first line (CRLF)
headerEnd := bytes.Index(buffer[:n], []byte("\r\n"))
if headerEnd == -1 {
// No PROXY protocol header found, treat as regular TLS connection
// Return the connection with the buffered data prepended
logger.Debug("No PROXY protocol header from trusted upstream %s, treating as regular TLS connection", remoteHost)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Create a reader that includes the buffered data + original connection
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
headerLine := string(buffer[:headerEnd])
remainingData := buffer[headerEnd+2 : n]
// Parse PROXY protocol line: "PROXY TCP4/TCP6 srcIP destIP srcPort destPort"
parts := strings.Fields(headerLine)
if len(parts) != 6 || parts[0] != "PROXY" {
// Check for PROXY UNKNOWN
if len(parts) == 2 && parts[0] == "PROXY" && parts[1] == "UNKNOWN" {
// PROXY UNKNOWN - use original connection info
return nil, conn, nil
}
// Invalid PROXY protocol, but might be regular TLS - treat as such
logger.Debug("Invalid PROXY protocol from trusted upstream %s, treating as regular TLS connection: %s", remoteHost, headerLine)
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
}
// Return the connection with all buffered data prepended
newReader := io.MultiReader(bytes.NewReader(buffer[:n]), conn)
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
return nil, wrappedConn, nil
}
protocol := parts[1]
srcIP := parts[2]
destIP := parts[3]
srcPort, err := strconv.Atoi(parts[4])
if err != nil {
return nil, conn, fmt.Errorf("invalid source port in PROXY header: %s", parts[4])
}
destPort, err := strconv.Atoi(parts[5])
if err != nil {
return nil, conn, fmt.Errorf("invalid destination port in PROXY header: %s", parts[5])
}
// Create a new reader that includes remaining data + original connection
var newReader io.Reader
if len(remainingData) > 0 {
newReader = io.MultiReader(bytes.NewReader(remainingData), conn)
} else {
newReader = conn
}
// Create a wrapper connection that reads from the combined reader
wrappedConn := &proxyProtocolConn{
Conn: conn,
reader: newReader,
}
proxyInfo := &ProxyProtocolInfo{
Protocol: protocol,
SrcIP: srcIP,
DestIP: destIP,
SrcPort: srcPort,
DestPort: destPort,
OriginalConn: wrappedConn,
}
// Clear read timeout
if err := conn.SetReadDeadline(time.Time{}); err != nil {
return nil, conn, fmt.Errorf("failed to clear read deadline: %w", err)
}
return proxyInfo, wrappedConn, nil
}
// proxyProtocolConn wraps a connection to read from a custom reader
type proxyProtocolConn struct {
net.Conn
reader io.Reader
}
func (c *proxyProtocolConn) Read(b []byte) (int, error) {
return c.reader.Read(b)
}
// buildProxyProtocolHeaderFromInfo creates a PROXY protocol v1 header using ProxyProtocolInfo
func (p *SNIProxy) buildProxyProtocolHeaderFromInfo(proxyInfo *ProxyProtocolInfo, targetAddr net.Addr) string {
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Use the original client information from the PROXY protocol
var targetIP string
var protocol string
// Parse source IP to determine protocol family
srcIP := net.ParseIP(proxyInfo.SrcIP)
if srcIP == nil {
return "PROXY UNKNOWN\r\n"
}
if srcIP.To4() != nil {
// Source is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Source is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
proxyInfo.SrcIP,
targetIP,
proxyInfo.SrcPort,
targetTCP.Port)
}
// buildProxyProtocolHeader creates a PROXY protocol v1 header
func buildProxyProtocolHeader(clientAddr, targetAddr net.Addr) string {
clientTCP, ok := clientAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
targetTCP, ok := targetAddr.(*net.TCPAddr)
if !ok {
// Fallback for unknown address types
return "PROXY UNKNOWN\r\n"
}
// Determine protocol family based on client IP and normalize target IP accordingly
var protocol string
var targetIP string
if clientTCP.IP.To4() != nil {
// Client is IPv4, use TCP4 protocol
protocol = "TCP4"
if targetTCP.IP.To4() != nil {
// Target is also IPv4, use as-is
targetIP = targetTCP.IP.String()
} else {
// Target is IPv6, but we need IPv4 for consistent protocol family
// Use the IPv4 loopback if target is IPv6 loopback, otherwise use 127.0.0.1
if targetTCP.IP.IsLoopback() {
targetIP = "127.0.0.1"
} else {
// For non-loopback IPv6 targets, we could try to extract embedded IPv4
// or fall back to a sensible IPv4 address based on the target
targetIP = "127.0.0.1" // Safe fallback
}
}
} else {
// Client is IPv6, use TCP6 protocol
protocol = "TCP6"
if targetTCP.IP.To4() != nil {
// Target is IPv4, convert to IPv6 representation
targetIP = "::ffff:" + targetTCP.IP.String()
} else {
// Target is also IPv6, use as-is
targetIP = targetTCP.IP.String()
}
}
return fmt.Sprintf("PROXY %s %s %s %d %d\r\n",
protocol,
clientTCP.IP.String(),
targetIP,
clientTCP.Port,
targetTCP.Port)
}
// NewSNIProxy creates a new SNI proxy instance
func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string, proxyProtocol bool, trustedUpstreams []string) (*SNIProxy, error) {
ctx, cancel := context.WithCancel(context.Background())
// Create local overrides map
overridesMap := make(map[string]struct{})
for _, domain := range localOverrides {
if domain != "" {
overridesMap[domain] = struct{}{}
}
}
// Create trusted upstreams map
trustedMap := make(map[string]struct{})
for _, upstream := range trustedUpstreams {
if upstream != "" {
// Add both the domain and potentially resolved IPs
trustedMap[upstream] = struct{}{}
// Try to resolve the domain to IPs and add them too
if ips, err := net.LookupIP(upstream); err == nil {
for _, ip := range ips {
trustedMap[ip.String()] = struct{}{}
}
}
}
}
proxy := &SNIProxy{
port: port,
cache: cache.New(3*time.Second, 10*time.Minute),
ctx: ctx,
cancel: cancel,
localProxyAddr: localProxyAddr,
localProxyPort: localProxyPort,
remoteConfigURL: remoteConfigURL,
publicKey: publicKey,
proxyProtocol: proxyProtocol,
localSNIs: make(map[string]struct{}),
localOverrides: overridesMap,
activeTunnels: make(map[string]*activeTunnel),
trustedUpstreams: trustedMap,
}
return proxy, nil
}
// Start begins listening for connections
func (p *SNIProxy) Start() error {
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
if err != nil {
return fmt.Errorf("failed to listen on port %d: %w", p.port, err)
}
p.listener = listener
logger.Debug("SNI Proxy listening on port %d", p.port)
// Accept connections in a goroutine
go p.acceptConnections()
return nil
}
// Stop gracefully shuts down the proxy
func (p *SNIProxy) Stop() error {
log.Println("Stopping SNI Proxy...")
p.cancel()
if p.listener != nil {
p.listener.Close()
}
// Wait for all goroutines to finish with timeout
done := make(chan struct{})
go func() {
p.wg.Wait()
close(done)
}()
select {
case <-done:
log.Println("All connections closed gracefully")
case <-time.After(30 * time.Second):
log.Println("Timeout waiting for connections to close")
}
log.Println("SNI Proxy stopped")
return nil
}
// acceptConnections handles incoming connections
func (p *SNIProxy) acceptConnections() {
for {
conn, err := p.listener.Accept()
if err != nil {
select {
case <-p.ctx.Done():
return
default:
logger.Debug("Accept error: %v", err)
continue
}
}
p.wg.Add(1)
go p.handleConnection(conn)
}
}
// readClientHello reads and parses the TLS ClientHello message
func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) {
var hello *tls.ClientHelloInfo
err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{
GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) {
hello = new(tls.ClientHelloInfo)
*hello = *argHello
return nil, nil
},
}).Handshake()
if hello == nil {
return nil, err
}
return hello, nil
}
// peekClientHello reads the ClientHello while preserving the data for forwarding
func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) {
peekedBytes := new(bytes.Buffer)
hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes))
if err != nil {
return nil, nil, err
}
return hello, io.MultiReader(peekedBytes, reader), nil
}
// extractSNI extracts the SNI hostname from the TLS ClientHello
func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) {
clientHello, clientReader, err := p.peekClientHello(conn)
if err != nil {
return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err)
}
if clientHello.ServerName == "" {
return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello")
}
return clientHello.ServerName, clientReader, nil
}
// handleConnection processes a single client connection
func (p *SNIProxy) handleConnection(clientConn net.Conn) {
defer p.wg.Done()
defer clientConn.Close()
logger.Debug("Accepted connection from %s", clientConn.RemoteAddr())
// Check for PROXY protocol from trusted upstream
var proxyInfo *ProxyProtocolInfo
var actualClientConn net.Conn = clientConn
if len(p.trustedUpstreams) > 0 {
var err error
proxyInfo, actualClientConn, err = p.parseProxyProtocolHeader(clientConn)
if err != nil {
logger.Debug("Failed to parse PROXY protocol: %v", err)
return
}
if proxyInfo != nil {
logger.Debug("Received PROXY protocol from trusted upstream: %s:%d -> %s:%d",
proxyInfo.SrcIP, proxyInfo.SrcPort, proxyInfo.DestIP, proxyInfo.DestPort)
} else {
// No PROXY protocol detected, but connection is from trusted upstream
// This is fine - treat as regular connection
logger.Debug("No PROXY protocol detected from trusted upstream, treating as regular connection")
}
}
// Set read timeout for SNI extraction
if err := actualClientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
logger.Debug("Failed to set read deadline: %v", err)
return
}
// Extract SNI hostname
hostname, clientReader, err := p.extractSNI(actualClientConn)
if err != nil {
logger.Debug("SNI extraction failed: %v", err)
return
}
if hostname == "" {
log.Println("No SNI hostname found")
return
}
logger.Debug("SNI hostname detected: %s", hostname)
// Remove read timeout for normal operation
if err := actualClientConn.SetReadDeadline(time.Time{}); err != nil {
logger.Debug("Failed to clear read deadline: %v", err)
return
}
// Get routing information - use original client address if available from PROXY protocol
var clientAddrStr string
if proxyInfo != nil {
clientAddrStr = fmt.Sprintf("%s:%d", proxyInfo.SrcIP, proxyInfo.SrcPort)
} else {
clientAddrStr = clientConn.RemoteAddr().String()
}
route, err := p.getRoute(hostname, clientAddrStr)
if err != nil {
logger.Debug("Failed to get route for %s: %v", hostname, err)
return
}
if route == nil {
logger.Debug("No route found for hostname: %s", hostname)
return
}
logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort)
// Connect to target server
targetConn, err := net.DialTimeout("tcp",
fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort),
10*time.Second)
if err != nil {
logger.Debug("Failed to connect to target %s:%d: %v",
route.TargetHost, route.TargetPort, err)
return
}
defer targetConn.Close()
logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort)
// Send PROXY protocol header if enabled
if p.proxyProtocol {
var proxyHeader string
if proxyInfo != nil {
// Use original client info from PROXY protocol
proxyHeader = p.buildProxyProtocolHeaderFromInfo(proxyInfo, targetConn.LocalAddr())
} else {
// Use direct client connection info
proxyHeader = buildProxyProtocolHeader(clientConn.RemoteAddr(), targetConn.LocalAddr())
}
logger.Debug("Sending PROXY protocol header: %s", strings.TrimSpace(proxyHeader))
if _, err := targetConn.Write([]byte(proxyHeader)); err != nil {
logger.Debug("Failed to send PROXY protocol header: %v", err)
return
}
}
// Track this tunnel by SNI
p.activeTunnelsLock.Lock()
tunnel, ok := p.activeTunnels[hostname]
if !ok {
tunnel = &activeTunnel{}
p.activeTunnels[hostname] = tunnel
}
tunnel.conns = append(tunnel.conns, actualClientConn)
p.activeTunnelsLock.Unlock()
defer func() {
// Remove this conn from active tunnels
p.activeTunnelsLock.Lock()
if tunnel, ok := p.activeTunnels[hostname]; ok {
newConns := make([]net.Conn, 0, len(tunnel.conns))
for _, c := range tunnel.conns {
if c != actualClientConn {
newConns = append(newConns, c)
}
}
if len(newConns) == 0 {
delete(p.activeTunnels, hostname)
} else {
tunnel.conns = newConns
}
}
p.activeTunnelsLock.Unlock()
}()
// Start bidirectional data transfer
p.pipe(actualClientConn, targetConn, clientReader)
}
// getRoute retrieves routing information for a hostname
func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) {
// Check local overrides first
if _, isOverride := p.localOverrides[hostname]; isOverride {
logger.Debug("Local override matched for hostname: %s", hostname)
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Fast path: check if hostname is in localSNIs
p.localSNIsLock.RLock()
_, isLocal := p.localSNIs[hostname]
p.localSNIsLock.RUnlock()
if isLocal {
return &RouteRecord{
Hostname: hostname,
TargetHost: p.localProxyAddr,
TargetPort: p.localProxyPort,
}, nil
}
// Check cache first
if cached, found := p.cache.Get(hostname); found {
if cached == nil {
return nil, nil // Cached negative result
}
logger.Debug("Cache hit for hostname: %s", hostname)
return cached.(*RouteRecord), nil
}
logger.Debug("Cache miss for hostname: %s, querying API", hostname)
// Query API with timeout
ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second)
defer cancel()
// Construct API URL (without hostname in path)
apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL)
// Create request body with hostname and public key
requestBody := map[string]string{
"hostname": hostname,
"publicKey": p.publicKey,
}
jsonBody, err := json.Marshal(requestBody)
if err != nil {
return nil, fmt.Errorf("failed to marshal request body: %w", err)
}
// Create HTTP request
req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody))
if err != nil {
return nil, fmt.Errorf("failed to create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
// Make HTTP request
client := &http.Client{Timeout: 5 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("API request failed: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
// Cache negative result for shorter time (1 minute)
p.cache.Set(hostname, nil, 1*time.Minute)
return nil, nil
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API returned status %d", resp.StatusCode)
}
// Parse response
var apiResponse RouteAPIResponse
if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil {
return nil, fmt.Errorf("failed to decode API response: %w", err)
}
endpoints := apiResponse.Endpoints
// Default target configuration
targetHost := p.localProxyAddr
targetPort := p.localProxyPort
// If no endpoints returned, use local node
if len(endpoints) == 0 {
logger.Debug("No endpoints returned for hostname: %s, using local node", hostname)
} else {
// Select endpoint using consistent hashing for stickiness
selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints)
targetHost = selectedEndpoint
targetPort = 443 // Default HTTPS port
logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr)
}
route := &RouteRecord{
Hostname: hostname,
TargetHost: targetHost,
TargetPort: targetPort,
}
// Cache the result
p.cache.Set(hostname, route, cache.DefaultExpiration)
logger.Debug("Cached route for hostname: %s", hostname)
return route, nil
}
// selectStickyEndpoint selects an endpoint using consistent hashing to ensure
// the same client always routes to the same endpoint for load balancing
func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string {
if len(endpoints) == 0 {
return p.localProxyAddr
}
if len(endpoints) == 1 {
return endpoints[0]
}
// Use FNV hash for consistent selection based on client address
hash := fnv.New32a()
hash.Write([]byte(clientAddr))
index := hash.Sum32() % uint32(len(endpoints))
return endpoints[index]
}
// pipe handles bidirectional data transfer between connections
func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) {
var wg sync.WaitGroup
wg.Add(2)
// Copy data from client to target (using the buffered reader)
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(targetConn, clientReader, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy client->target error: %v", err)
}
}()
// Copy data from target to client
go func() {
defer wg.Done()
defer func() {
if tcpConn, ok := clientConn.(*net.TCPConn); ok {
tcpConn.CloseWrite()
}
}()
// Use a large buffer for better performance
buf := make([]byte, 32*1024)
_, err := io.CopyBuffer(clientConn, targetConn, buf)
if err != nil && err != io.EOF {
logger.Debug("Copy target->client error: %v", err)
}
}()
wg.Wait()
}
// GetCacheStats returns cache statistics
func (p *SNIProxy) GetCacheStats() (int, int) {
return p.cache.ItemCount(), len(p.cache.Items())
}
// ClearCache clears all cached entries
func (p *SNIProxy) ClearCache() {
p.cache.Flush()
log.Println("Cache cleared")
}
// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains
func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) {
newSNIs := make(map[string]struct{})
for _, domain := range fullDomains {
newSNIs[domain] = struct{}{}
// Invalidate any cached route for this domain
p.cache.Delete(domain)
}
// Update localSNIs
p.localSNIsLock.Lock()
removed := make([]string, 0)
for sni := range p.localSNIs {
if _, stillLocal := newSNIs[sni]; !stillLocal {
removed = append(removed, sni)
}
}
p.localSNIs = newSNIs
p.localSNIsLock.Unlock()
logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed))
// Terminate tunnels for removed SNIs
if len(removed) > 0 {
p.activeTunnelsLock.Lock()
for _, sni := range removed {
if tunnels, ok := p.activeTunnels[sni]; ok {
for _, conn := range tunnels.conns {
conn.Close()
}
delete(p.activeTunnels, sni)
logger.Debug("Closed tunnels for SNI target change: %s", sni)
}
}
p.activeTunnelsLock.Unlock()
}
}