diff --git a/proxy/proxy.go b/proxy/proxy.go index f29878e..5c905f0 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -72,7 +72,7 @@ type SNIProxy struct { } type activeTunnel struct { - conns []net.Conn + conns map[net.Conn]struct{} } // readOnlyConn is a wrapper for io.Reader that implements net.Conn @@ -592,26 +592,19 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { p.activeTunnelsLock.Lock() tunnel, ok := p.activeTunnels[hostname] if !ok { - tunnel = &activeTunnel{} + tunnel = &activeTunnel{conns: make(map[net.Conn]struct{})} p.activeTunnels[hostname] = tunnel } - tunnel.conns = append(tunnel.conns, actualClientConn) + tunnel.conns[actualClientConn] = struct{}{} p.activeTunnelsLock.Unlock() defer func() { - // Remove this conn from active tunnels + // Remove this conn from active tunnels - O(1) with map p.activeTunnelsLock.Lock() if tunnel, ok := p.activeTunnels[hostname]; ok { - newConns := make([]net.Conn, 0, len(tunnel.conns)) - for _, c := range tunnel.conns { - if c != actualClientConn { - newConns = append(newConns, c) - } - } - if len(newConns) == 0 { + delete(tunnel.conns, actualClientConn) + if len(tunnel.conns) == 0 { delete(p.activeTunnels, hostname) - } else { - tunnel.conns = newConns } } p.activeTunnelsLock.Unlock() @@ -810,32 +803,42 @@ func (p *SNIProxy) ClearCache() { // UpdateLocalSNIs updates the local SNIs and invalidates cache for changed domains func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { - newSNIs := make(map[string]struct{}) + newSNIs := make(map[string]struct{}, len(fullDomains)) for _, domain := range fullDomains { newSNIs[domain] = struct{}{} - // Invalidate any cached route for this domain - p.cache.Delete(domain) } - // Update localSNIs - p.localSNIsLock.Lock() + // Get old SNIs with read lock to compute diff outside write lock + p.localSNIsLock.RLock() + oldSNIs := p.localSNIs + p.localSNIsLock.RUnlock() + + // Compute removed SNIs outside the lock removed := make([]string, 0) - for sni := range p.localSNIs { + for sni := range oldSNIs { if _, stillLocal := newSNIs[sni]; !stillLocal { removed = append(removed, sni) } } + + // Swap with minimal write lock hold time + p.localSNIsLock.Lock() p.localSNIs = newSNIs p.localSNIsLock.Unlock() + // Invalidate cache for new domains (cache is thread-safe) + for domain := range newSNIs { + p.cache.Delete(domain) + } + logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed)) // Terminate tunnels for removed SNIs if len(removed) > 0 { p.activeTunnelsLock.Lock() for _, sni := range removed { - if tunnels, ok := p.activeTunnels[sni]; ok { - for _, conn := range tunnels.conns { + if tunnel, ok := p.activeTunnels[sni]; ok { + for conn := range tunnel.conns { conn.Close() } delete(p.activeTunnels, sni)