diff --git a/proxy/proxy.go b/proxy/proxy.go index f29878e..7af063e 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -14,10 +14,12 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/fosrl/gerbil/logger" "github.com/patrickmn/go-cache" + "golang.org/x/sync/errgroup" ) // RouteRecord represents a routing configuration @@ -72,7 +74,9 @@ type SNIProxy struct { } type activeTunnel struct { - conns []net.Conn + ctx context.Context + cancel context.CancelFunc + count atomic.Int64 } // readOnlyConn is a wrapper for io.Reader that implements net.Conn @@ -588,37 +592,33 @@ func (p *SNIProxy) handleConnection(clientConn net.Conn) { } } - // Track this tunnel by SNI + // Track this tunnel by SNI using context for cancellation p.activeTunnelsLock.Lock() tunnel, ok := p.activeTunnels[hostname] if !ok { - tunnel = &activeTunnel{} + ctx, cancel := context.WithCancel(p.ctx) + tunnel = &activeTunnel{ctx: ctx, cancel: cancel} p.activeTunnels[hostname] = tunnel } - tunnel.conns = append(tunnel.conns, actualClientConn) + tunnel.count.Add(1) + tunnelCtx := tunnel.ctx p.activeTunnelsLock.Unlock() defer func() { - // Remove this conn from active tunnels - p.activeTunnelsLock.Lock() - if tunnel, ok := p.activeTunnels[hostname]; ok { - newConns := make([]net.Conn, 0, len(tunnel.conns)) - for _, c := range tunnel.conns { - if c != actualClientConn { - newConns = append(newConns, c) - } - } - if len(newConns) == 0 { + // Decrement count atomically; if we're the last connection, clean up + if tunnel.count.Add(-1) == 0 { + tunnel.cancel() + p.activeTunnelsLock.Lock() + // Only delete if the map still points to our tunnel + if p.activeTunnels[hostname] == tunnel { delete(p.activeTunnels, hostname) - } else { - tunnel.conns = newConns } + p.activeTunnelsLock.Unlock() } - p.activeTunnelsLock.Unlock() }() - // Start bidirectional data transfer - p.pipe(actualClientConn, targetConn, clientReader) + // Start bidirectional data transfer with tunnel context + p.pipe(tunnelCtx, actualClientConn, targetConn, clientReader) } // getRoute retrieves routing information for a hostname @@ -754,47 +754,36 @@ func (p *SNIProxy) selectStickyEndpoint(clientAddr string, endpoints []string) s } // pipe handles bidirectional data transfer between connections -func (p *SNIProxy) pipe(clientConn, targetConn net.Conn, clientReader io.Reader) { - var wg sync.WaitGroup - wg.Add(2) +func (p *SNIProxy) pipe(ctx context.Context, clientConn, targetConn net.Conn, clientReader io.Reader) { + g, ctx := errgroup.WithContext(ctx) - // closeOnce ensures we only close connections once - var closeOnce sync.Once - closeConns := func() { - closeOnce.Do(func() { - // Close both connections to unblock any pending reads - clientConn.Close() - targetConn.Close() - }) - } + // Close connections when context cancels to unblock io.Copy operations + context.AfterFunc(ctx, func() { + clientConn.Close() + targetConn.Close() + }) - // Copy data from client to target (using the buffered reader) - go func() { - defer wg.Done() - defer closeConns() - - // Use a large buffer for better performance + // Copy data from client to target + g.Go(func() error { buf := make([]byte, 32*1024) _, err := io.CopyBuffer(targetConn, clientReader, buf) if err != nil && err != io.EOF { logger.Debug("Copy client->target error: %v", err) } - }() + return err + }) // Copy data from target to client - go func() { - defer wg.Done() - defer closeConns() - - // Use a large buffer for better performance + g.Go(func() error { buf := make([]byte, 32*1024) _, err := io.CopyBuffer(clientConn, targetConn, buf) if err != nil && err != io.EOF { logger.Debug("Copy target->client error: %v", err) } - }() + return err + }) - wg.Wait() + g.Wait() } // GetCacheStats returns cache statistics @@ -830,16 +819,14 @@ func (p *SNIProxy) UpdateLocalSNIs(fullDomains []string) { logger.Debug("Updated local SNIs, added %d, removed %d", len(newSNIs), len(removed)) - // Terminate tunnels for removed SNIs + // Terminate tunnels for removed SNIs via context cancellation if len(removed) > 0 { p.activeTunnelsLock.Lock() for _, sni := range removed { - if tunnels, ok := p.activeTunnels[sni]; ok { - for _, conn := range tunnels.conns { - conn.Close() - } + if tunnel, ok := p.activeTunnels[sni]; ok { + tunnel.cancel() delete(p.activeTunnels, sni) - logger.Debug("Closed tunnels for SNI target change: %s", sni) + logger.Debug("Cancelled tunnel context for SNI target change: %s", sni) } } p.activeTunnelsLock.Unlock()