[proxy] Support WebSocket (#5312)

* Fix WebSocket support by implementing Hijacker interface

Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces
(Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware.

Without this delegation:
 - WebSocket connections fail (can't hijack the connection)
 - Streaming breaks (can't flush buffers)
 - HTTP/2 push doesn't work

* Add HijackTracker to manage hijacked connections during graceful shutdown

* Refactor HijackTracker to use middleware for tracking hijacked connections

* Refactor server handler chain setup for improved readability and maintainability
This commit is contained in:
Zoltan Papp
2026-02-17 12:53:34 +01:00
committed by GitHub
parent 0146e39714
commit 1bd7190954
7 changed files with 180 additions and 23 deletions

View File

@@ -37,6 +37,7 @@ import (
"github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/proxy/internal/auth"
"github.com/netbirdio/netbird/proxy/internal/certwatch"
"github.com/netbirdio/netbird/proxy/internal/conntrack"
"github.com/netbirdio/netbird/proxy/internal/debug"
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
"github.com/netbirdio/netbird/proxy/internal/health"
@@ -64,6 +65,11 @@ type Server struct {
healthChecker *health.Checker
meter *metrics.Metrics
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
// so they can be closed during graceful shutdown, since http.Server.Shutdown
// does not handle them.
hijackTracker conntrack.HijackTracker
// Mostly used for debugging on management.
startTime time.Time
@@ -185,10 +191,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
return err
}
// Build the handler chain from inside out.
handler := http.Handler(s.proxy)
handler = s.auth.Protect(handler)
handler = web.AssetHandler(handler)
handler = accessLog.Middleware(handler)
handler = s.meter.Middleware(handler)
handler = s.hijackTracker.Middleware(handler)
// Start the reverse proxy HTTPS server.
s.https = &http.Server{
Addr: addr,
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
Handler: handler,
TLSConfig: tlsConfig,
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
}
@@ -457,7 +471,12 @@ func (s *Server) gracefulShutdown() {
s.Logger.Warnf("https server drain: %v", err)
}
// Step 4: Stop all remaining background services.
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
if n := s.hijackTracker.CloseAll(); n > 0 {
s.Logger.Infof("closed %d hijacked connection(s)", n)
}
// Step 5: Stop all remaining background services.
s.shutdownServices()
s.Logger.Info("graceful shutdown complete")
}