diff --git a/README.md b/README.md index 830085b..54c3b55 100644 --- a/README.md +++ b/README.md @@ -62,8 +62,7 @@ Example: ./gerbil \ --reachableAt=http://gerbil:3003 \ --generateAndSaveKeyTo=/var/config/key \ ---remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config \ ---reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth +--remoteConfig=http://pangolin:3001/api/v1/ ``` ```yaml @@ -75,8 +74,7 @@ services: command: - --reachableAt=http://gerbil:3003 - --generateAndSaveKeyTo=/var/config/key - - --remoteConfig=http://pangolin:3001/api/v1/gerbil/get-config - - --reportBandwidthTo=http://pangolin:3001/api/v1/gerbil/receive-bandwidth + - --remoteConfig=http://pangolin:3001/api/v1/ volumes: - ./config/:/var/config cap_add: diff --git a/go.mod b/go.mod index 46d796e..c17865a 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/mdlayher/genetlink v1.3.2 // indirect github.com/mdlayher/netlink v1.7.2 // indirect github.com/mdlayher/socket v0.4.1 // indirect + github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/vishvananda/netns v0.0.5 // indirect golang.org/x/crypto v0.36.0 // indirect golang.org/x/net v0.38.0 // indirect diff --git a/go.sum b/go.sum index 9a60eee..a089b8e 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/mdlayher/socket v0.4.1 h1:eM9y2/jlbs1M615oshPQOHZzj6R6wMT7bX5NPiQvn2U github.com/mdlayher/socket v0.4.1/go.mod h1:cAqeGjoufqdxWkD7DkpyS+wcefOtmu5OQ8KuoJGIReA= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= +github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= +github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= diff --git a/main.go b/main.go index e2d2c16..37350ce 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ import ( "time" "github.com/fosrl/gerbil/logger" + "github.com/fosrl/gerbil/proxy" "github.com/fosrl/gerbil/relay" "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/wgctrl" @@ -32,7 +33,8 @@ var ( mu sync.Mutex wgMu sync.Mutex // Protects WireGuard operations notifyURL string - proxyServer *relay.UDPProxyServer + proxyRelay *relay.UDPProxyServer + proxySNI *proxy.SNIProxy ) type WgConfig struct { @@ -115,6 +117,10 @@ func main() { reachableAt string logLevel string mtu string + sniProxyPort int + localProxyAddr string + localProxyPort int + localOverridesStr string ) interfaceName = os.Getenv("INTERFACE") @@ -127,6 +133,11 @@ func main() { mtu = os.Getenv("MTU") notifyURL = os.Getenv("NOTIFY_URL") + sniProxyPortStr := os.Getenv("SNI_PORT") + localProxyAddr = os.Getenv("LOCAL_PROXY") + localProxyPortStr := os.Getenv("LOCAL_PROXY_PORT") + localOverridesStr = os.Getenv("LOCAL_OVERRIDES") + if interfaceName == "" { flag.StringVar(&interfaceName, "interface", "wg0", "Name of the WireGuard interface") } @@ -159,6 +170,32 @@ func main() { if notifyURL == "" { flag.StringVar(¬ifyURL, "notify", "", "URL to notify on peer changes") } + + if sniProxyPortStr != "" { + if port, err := strconv.Atoi(sniProxyPortStr); err == nil { + sniProxyPort = port + } + } + if sniProxyPortStr == "" { + flag.IntVar(&sniProxyPort, "sni-port", 8443, "Port to listen on") + } + + if localProxyAddr == "" { + flag.StringVar(&localProxyAddr, "local-proxy", "localhost", "Local proxy address") + } + + if localProxyPortStr != "" { + if port, err := strconv.Atoi(localProxyPortStr); err == nil { + localProxyPort = port + } + } + if localProxyPortStr == "" { + flag.IntVar(&localProxyPort, "local-proxy-port", 443, "Local proxy port") + } + if localOverridesStr != "" { + flag.StringVar(&localOverridesStr, "local-overrides", "", "Comma-separated list of local overrides for SNI proxy") + } + flag.Parse() logger.Init() @@ -258,17 +295,39 @@ func main() { go periodicBandwidthCheck(remoteConfigURL + "/gerbil/receive-bandwidth") // Start the UDP proxy server - proxyServer = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) - err = proxyServer.Start() + proxyRelay = relay.NewUDPProxyServer(":21820", remoteConfigURL, key, reachableAt) + err = proxyRelay.Start() if err != nil { logger.Fatal("Failed to start UDP proxy server: %v", err) } - defer proxyServer.Stop() + defer proxyRelay.Stop() + + // TODO: WE SHOULD PULL THIS OUT OF THE CONFIG OR SOMETHING + // SO YOU DON'T NEED TO SET THIS SEPARATELY + // Parse local overrides + var localOverrides []string + if localOverridesStr != "" { + localOverrides = strings.Split(localOverridesStr, ",") + for i, domain := range localOverrides { + localOverrides[i] = strings.TrimSpace(domain) + } + logger.Info("Local overrides configured: %v", localOverrides) + } + + proxySNI, err = proxy.NewSNIProxy(sniProxyPort, remoteConfigURL, key.PublicKey().String(), localProxyAddr, localProxyPort, localOverrides) + if err != nil { + logger.Fatal("Failed to create proxy: %v", err) + } + + if err := proxySNI.Start(); err != nil { + logger.Fatal("Failed to start proxy: %v", err) + } // Set up HTTP server http.HandleFunc("/peer", handlePeer) http.HandleFunc("/update-proxy-mapping", handleUpdateProxyMapping) http.HandleFunc("/update-destinations", handleUpdateDestinations) + http.HandleFunc("/update-local-snis", handleUpdateLocalSNIs) logger.Info("Starting HTTP server on %s", listenAddr) // Run HTTP server in a goroutine @@ -647,9 +706,9 @@ func addPeerInternal(peer Peer) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerAdded(wgIP) + proxyRelay.OnPeerAdded(wgIP) } } @@ -692,7 +751,7 @@ func removePeerInternal(publicKey string) error { // Get current peer info before removing to clear relay connections var wgIPs []string - if proxyServer != nil { + if proxyRelay != nil { device, err := wgClient.Device(interfaceName) if err == nil { for _, peer := range device.Peers { @@ -721,9 +780,9 @@ func removePeerInternal(publicKey string) error { } // Clear relay connections for the peer's WireGuard IPs - if proxyServer != nil { + if proxyRelay != nil { for _, wgIP := range wgIPs { - proxyServer.OnPeerRemoved(wgIP) + proxyRelay.OnPeerRemoved(wgIP) } } @@ -760,13 +819,13 @@ func handleUpdateProxyMapping(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - updatedCount := proxyServer.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) + updatedCount := proxyRelay.UpdateDestinationInMappings(update.OldDestination, update.NewDestination) logger.Info("Updated %d proxy mappings: %s:%d -> %s:%d", updatedCount, @@ -830,13 +889,13 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { } // Update the proxy mappings in the relay server - if proxyServer == nil { + if proxyRelay == nil { logger.Error("Proxy server is not available") http.Error(w, "Proxy server is not available", http.StatusInternalServerError) return } - proxyServer.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) + proxyRelay.UpdateProxyMapping(request.SourceIP, request.SourcePort, request.Destinations) logger.Info("Updated proxy mapping for %s:%d with %d destinations", request.SourceIP, request.SourcePort, len(request.Destinations)) @@ -851,6 +910,32 @@ func handleUpdateDestinations(w http.ResponseWriter, r *http.Request) { }) } +// UpdateLocalSNIsRequest represents the JSON payload for updating local SNIs +type UpdateLocalSNIsRequest struct { + FullDomains []string `json:"fullDomains"` +} + +func handleUpdateLocalSNIs(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + logger.Error("Invalid method: %s", r.Method) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + var req UpdateLocalSNIsRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "Invalid JSON payload", http.StatusBadRequest) + return + } + + proxySNI.UpdateLocalSNIs(req.FullDomains) + + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "Local SNIs updated successfully", + }) +} + func periodicBandwidthCheck(endpoint string) { ticker := time.NewTicker(10 * time.Second) defer ticker.Stop() diff --git a/proxy/proxy.go b/proxy/proxy.go new file mode 100644 index 0000000..7edd92e --- /dev/null +++ b/proxy/proxy.go @@ -0,0 +1,522 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/tls" + "encoding/json" + "fmt" + "hash/fnv" + "io" + "log" + "net" + "net/http" + "sync" + "time" + + "github.com/fosrl/gerbil/logger" + "github.com/patrickmn/go-cache" +) + +// RouteRecord represents a routing configuration +type RouteRecord struct { + Hostname string + TargetHost string + TargetPort int +} + +// RouteAPIResponse represents the response from the route API +type RouteAPIResponse struct { + Endpoints []string `json:"endpoints"` +} + +// SNIProxy represents the main proxy server +type SNIProxy struct { + port int + cache *cache.Cache + listener net.Listener + ctx context.Context + cancel context.CancelFunc + wg sync.WaitGroup + localProxyAddr string + localProxyPort int + remoteConfigURL string + publicKey string + + // New fields for fast local SNI lookup + localSNIs map[string]struct{} + localSNIsLock sync.RWMutex + + // Local overrides for domains that should always use local proxy + localOverrides map[string]struct{} + + // Track active tunnels by SNI + activeTunnels map[string]*activeTunnel + activeTunnelsLock sync.Mutex +} + +type activeTunnel struct { + conns []net.Conn +} + +// readOnlyConn is a wrapper for io.Reader that implements net.Conn +type readOnlyConn struct { + reader io.Reader +} + +func (conn readOnlyConn) Read(p []byte) (int, error) { return conn.reader.Read(p) } +func (conn readOnlyConn) Write(p []byte) (int, error) { return 0, io.ErrClosedPipe } +func (conn readOnlyConn) Close() error { return nil } +func (conn readOnlyConn) LocalAddr() net.Addr { return nil } +func (conn readOnlyConn) RemoteAddr() net.Addr { return nil } +func (conn readOnlyConn) SetDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetReadDeadline(t time.Time) error { return nil } +func (conn readOnlyConn) SetWriteDeadline(t time.Time) error { return nil } + +// NewSNIProxy creates a new SNI proxy instance +func NewSNIProxy(port int, remoteConfigURL, publicKey, localProxyAddr string, localProxyPort int, localOverrides []string) (*SNIProxy, error) { + ctx, cancel := context.WithCancel(context.Background()) + + // Create local overrides map + overridesMap := make(map[string]struct{}) + for _, domain := range localOverrides { + if domain != "" { + overridesMap[domain] = struct{}{} + } + } + + proxy := &SNIProxy{ + port: port, + cache: cache.New(3*time.Second, 10*time.Minute), + ctx: ctx, + cancel: cancel, + localProxyAddr: localProxyAddr, + localProxyPort: localProxyPort, + remoteConfigURL: remoteConfigURL, + publicKey: publicKey, + localSNIs: make(map[string]struct{}), + localOverrides: overridesMap, + activeTunnels: make(map[string]*activeTunnel), + } + + return proxy, nil +} + +// Start begins listening for connections +func (p *SNIProxy) Start() error { + listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port)) + if err != nil { + return fmt.Errorf("failed to listen on port %d: %w", p.port, err) + } + + p.listener = listener + logger.Debug("SNI Proxy listening on port %d", p.port) + + // Accept connections in a goroutine + go p.acceptConnections() + + return nil +} + +// Stop gracefully shuts down the proxy +func (p *SNIProxy) Stop() error { + log.Println("Stopping SNI Proxy...") + + p.cancel() + + if p.listener != nil { + p.listener.Close() + } + + // Wait for all goroutines to finish with timeout + done := make(chan struct{}) + go func() { + p.wg.Wait() + close(done) + }() + + select { + case <-done: + log.Println("All connections closed gracefully") + case <-time.After(30 * time.Second): + log.Println("Timeout waiting for connections to close") + } + + log.Println("SNI Proxy stopped") + return nil +} + +// acceptConnections handles incoming connections +func (p *SNIProxy) acceptConnections() { + for { + conn, err := p.listener.Accept() + if err != nil { + select { + case <-p.ctx.Done(): + return + default: + logger.Debug("Accept error: %v", err) + continue + } + } + + p.wg.Add(1) + go p.handleConnection(conn) + } +} + +// readClientHello reads and parses the TLS ClientHello message +func (p *SNIProxy) readClientHello(reader io.Reader) (*tls.ClientHelloInfo, error) { + var hello *tls.ClientHelloInfo + err := tls.Server(readOnlyConn{reader: reader}, &tls.Config{ + GetConfigForClient: func(argHello *tls.ClientHelloInfo) (*tls.Config, error) { + hello = new(tls.ClientHelloInfo) + *hello = *argHello + return nil, nil + }, + }).Handshake() + if hello == nil { + return nil, err + } + return hello, nil +} + +// peekClientHello reads the ClientHello while preserving the data for forwarding +func (p *SNIProxy) peekClientHello(reader io.Reader) (*tls.ClientHelloInfo, io.Reader, error) { + peekedBytes := new(bytes.Buffer) + hello, err := p.readClientHello(io.TeeReader(reader, peekedBytes)) + if err != nil { + return nil, nil, err + } + return hello, io.MultiReader(peekedBytes, reader), nil +} + +// extractSNI extracts the SNI hostname from the TLS ClientHello +func (p *SNIProxy) extractSNI(conn net.Conn) (string, io.Reader, error) { + clientHello, clientReader, err := p.peekClientHello(conn) + if err != nil { + return "", nil, fmt.Errorf("failed to peek ClientHello: %w", err) + } + + if clientHello.ServerName == "" { + return "", clientReader, fmt.Errorf("no SNI hostname found in ClientHello") + } + + return clientHello.ServerName, clientReader, nil +} + +// handleConnection processes a single client connection +func (p *SNIProxy) handleConnection(clientConn net.Conn) { + defer p.wg.Done() + defer clientConn.Close() + + logger.Debug("Accepted connection from %s", clientConn.RemoteAddr()) + + // Set read timeout for SNI extraction + if err := clientConn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + logger.Debug("Failed to set read deadline: %v", err) + return + } + + // Extract SNI hostname + hostname, clientReader, err := p.extractSNI(clientConn) + if err != nil { + logger.Debug("SNI extraction failed: %v", err) + return + } + + if hostname == "" { + log.Println("No SNI hostname found") + return + } + + logger.Debug("SNI hostname detected: %s", hostname) + + // Remove read timeout for normal operation + if err := clientConn.SetReadDeadline(time.Time{}); err != nil { + logger.Debug("Failed to clear read deadline: %v", err) + return + } + + // Get routing information + route, err := p.getRoute(hostname, clientConn.RemoteAddr().String()) + if err != nil { + logger.Debug("Failed to get route for %s: %v", hostname, err) + return + } + + if route == nil { + logger.Debug("No route found for hostname: %s", hostname) + return + } + + logger.Debug("Routing %s to %s:%d", hostname, route.TargetHost, route.TargetPort) + + // Connect to target server + targetConn, err := net.DialTimeout("tcp", + fmt.Sprintf("%s:%d", route.TargetHost, route.TargetPort), + 10*time.Second) + if err != nil { + logger.Debug("Failed to connect to target %s:%d: %v", + route.TargetHost, route.TargetPort, err) + return + } + defer targetConn.Close() + + logger.Debug("Connected to target: %s:%d", route.TargetHost, route.TargetPort) + + // Track this tunnel by SNI + p.activeTunnelsLock.Lock() + tunnel, ok := p.activeTunnels[hostname] + if !ok { + tunnel = &activeTunnel{} + p.activeTunnels[hostname] = tunnel + } + tunnel.conns = append(tunnel.conns, clientConn) + p.activeTunnelsLock.Unlock() + + defer func() { + // Remove this conn from active tunnels + p.activeTunnelsLock.Lock() + if tunnel, ok := p.activeTunnels[hostname]; ok { + newConns := make([]net.Conn, 0, len(tunnel.conns)) + for _, c := range tunnel.conns { + if c != clientConn { + newConns = append(newConns, c) + } + } + if len(newConns) == 0 { + delete(p.activeTunnels, hostname) + } else { + tunnel.conns = newConns + } + } + p.activeTunnelsLock.Unlock() + }() + + // Start bidirectional data transfer + p.pipe(clientConn, targetConn, clientReader) +} + +// getRoute retrieves routing information for a hostname +func (p *SNIProxy) getRoute(hostname, clientAddr string) (*RouteRecord, error) { + // Check local overrides first + if _, isOverride := p.localOverrides[hostname]; isOverride { + logger.Debug("Local override matched for hostname: %s", hostname) + return &RouteRecord{ + Hostname: hostname, + TargetHost: p.localProxyAddr, + TargetPort: p.localProxyPort, + }, nil + } + + // Fast path: check if hostname is in localSNIs + p.localSNIsLock.RLock() + _, isLocal := p.localSNIs[hostname] + p.localSNIsLock.RUnlock() + if isLocal { + return &RouteRecord{ + Hostname: hostname, + TargetHost: p.localProxyAddr, + TargetPort: p.localProxyPort, + }, nil + } + + // Check cache first + if cached, found := p.cache.Get(hostname); found { + if cached == nil { + return nil, nil // Cached negative result + } + logger.Debug("Cache hit for hostname: %s", hostname) + return cached.(*RouteRecord), nil + } + + logger.Debug("Cache miss for hostname: %s, querying API", hostname) + + // Query API with timeout + ctx, cancel := context.WithTimeout(p.ctx, 5*time.Second) + defer cancel() + + // Construct API URL (without hostname in path) + apiURL := fmt.Sprintf("%s/gerbil/get-resolved-hostname", p.remoteConfigURL) + + // Create request body with hostname and public key + requestBody := map[string]string{ + "hostname": hostname, + "publicKey": p.publicKey, + } + + jsonBody, err := json.Marshal(requestBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "POST", apiURL, bytes.NewBuffer(jsonBody)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + // Make HTTP request + client := &http.Client{Timeout: 5 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("API request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusNotFound { + // Cache negative result for shorter time (1 minute) + p.cache.Set(hostname, nil, 1*time.Minute) + return nil, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API returned status %d", resp.StatusCode) + } + + // Parse response + var apiResponse RouteAPIResponse + if err := json.NewDecoder(resp.Body).Decode(&apiResponse); err != nil { + return nil, fmt.Errorf("failed to decode API response: %w", err) + } + + endpoints := apiResponse.Endpoints + + // Default target configuration + targetHost := p.localProxyAddr + targetPort := p.localProxyPort + + // If no endpoints returned, use local node + if len(endpoints) == 0 { + logger.Debug("No endpoints returned for hostname: %s, using local node", hostname) + } else { + // Select endpoint using consistent hashing for stickiness + selectedEndpoint := p.selectStickyEndpoint(clientAddr, endpoints) + targetHost = selectedEndpoint + targetPort = 443 // Default HTTPS port + logger.Debug("Selected endpoint %s for hostname %s from client %s", selectedEndpoint, hostname, clientAddr) + } + + route := &RouteRecord{ + Hostname: hostname, + TargetHost: targetHost, + TargetPort: targetPort, + } + + // Cache the result + p.cache.Set(hostname, route, cache.DefaultExpiration) + logger.Debug("Cached route for hostname: %s", hostname) + + return route, nil +} + +// selectStickyEndpoint selects an endpoint using consistent hashing to ensure +// the same client always routes to the same endpoint for load balancing +func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) string { + if len(endpoints) == 0 { + return p.localProxyAddr + } + if len(endpoints) == 1 { + return endpoints[0] + } + + // Use FNV hash for consistent selection based on client address + hash := fnv.New32a() + hash.Write([]byte(clientAddr)) + index := hash.Sum32() % uint32(len(endpoints)) + + return endpoints[index] +} + +// pipe handles bidirectional data transfer between connections +func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { + var wg sync.WaitGroup + wg.Add(2) + + // Copy data from client to target (using the buffered reader) + go func() { + defer wg.Done() + defer func() { + if tcpConn, ok := targetConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } + }() + + // Use a large buffer for better performance + buf := make([]byte, 32*1024) + _, err := io.CopyBuffer(targetConn, clientReader, buf) + if err != nil && err != io.EOF { + logger.Debug("Copy client->target error: %v", err) + } + }() + + // Copy data from target to client + go func() { + defer wg.Done() + defer func() { + if tcpConn, ok := clientConn.(*net.TCPConn); ok { + tcpConn.CloseWrite() + } + }() + + // Use a large buffer for better performance + buf := make([]byte, 32*1024) + _, err := io.CopyBuffer(clientConn, targetConn, buf) + if err != nil && err != io.EOF { + logger.Debug("Copy target->client error: %v", err) + } + }() + + wg.Wait() +} + +// GetCacheStats returns cache statistics +func (p *SNIProxy) GetCacheStats() (int, int) { + return p.cache.ItemCount(), len(p.cache.Items()) +} + +// ClearCache clears all cached entries +func (p *SNIProxy) ClearCache() { + p.cache.Flush() + log.Println("Cache cleared") +} + +// UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains +func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { + newSNIs := make(map[string]struct{}) + for _, domain := range fullDomains { + newSNIs[domain] = struct{}{} + // Invalidate any cached route for this domain + p.cache.Delete(domain) + } + + // Update localSNIs + p.localSNIsLock.Lock() + removed := make([]string, 0) + for sni := range p.localSNIs { + if _, stillLocal := newSNIs[sni]; !stillLocal { + removed = append(removed, sni) + } + } + p.localSNIs = newSNIs + p.localSNIsLock.Unlock() + + logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed)) + + // Terminate tunnels for removed SNIs + if len(removed) > 0 { + p.activeTunnelsLock.Lock() + for _, sni := range removed { + if tunnels, ok := p.activeTunnels[sni]; ok { + for _, conn := range tunnels.conns { + conn.Close() + } + delete(p.activeTunnels, sni) + logger.Debug("Closed tunnels for SNI target change: %s", sni) + } + } + p.activeTunnelsLock.Unlock() + } +}