diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index ca7556bfd..dd4798975 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -9,6 +9,7 @@ import ( "github.com/rs/xid" "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/responsewriter" "github.com/netbirdio/netbird/proxy/web" ) @@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { // Use a response writer wrapper so we can access the status code later. sw := &statusWriter{ - w: w, - status: http.StatusOK, + PassthroughWriter: responsewriter.New(w), + status: http.StatusOK, } // Resolve the source IP using trusted proxy configuration before passing diff --git a/proxy/internal/accesslog/statuswriter.go b/proxy/internal/accesslog/statuswriter.go index 56ef90efa..43cda59f9 100644 --- a/proxy/internal/accesslog/statuswriter.go +++ b/proxy/internal/accesslog/statuswriter.go @@ -1,26 +1,18 @@ package accesslog import ( - "net/http" + "github.com/netbirdio/netbird/proxy/internal/responsewriter" ) -// statusWriter is a simple wrapper around an http.ResponseWriter -// that captures the setting of the status code via the WriteHeader -// function and stores it so that it can be retrieved later. +// statusWriter captures the HTTP status code from WriteHeader calls. +// It embeds responsewriter.PassthroughWriter which handles all the optional +// interfaces (Hijacker, Flusher, Pusher) automatically. type statusWriter struct { - w http.ResponseWriter + *responsewriter.PassthroughWriter status int } -func (w *statusWriter) Header() http.Header { - return w.w.Header() -} - -func (w *statusWriter) Write(data []byte) (int, error) { - return w.w.Write(data) -} - func (w *statusWriter) WriteHeader(status int) { w.status = status - w.w.WriteHeader(status) + w.PassthroughWriter.WriteHeader(status) } diff --git a/proxy/internal/conntrack/conn.go b/proxy/internal/conntrack/conn.go new file mode 100644 index 000000000..97055d992 --- /dev/null +++ b/proxy/internal/conntrack/conn.go @@ -0,0 +1,49 @@ +package conntrack + +import ( + "bufio" + "net" + "net/http" +) + +// trackedConn wraps a net.Conn and removes itself from the tracker on Close. +type trackedConn struct { + net.Conn + tracker *HijackTracker +} + +func (c *trackedConn) Close() error { + c.tracker.conns.Delete(c) + return c.Conn.Close() +} + +// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls +// to replace the raw connection with a trackedConn that auto-deregisters. +type trackingWriter struct { + http.ResponseWriter + tracker *HijackTracker +} + +func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + hijacker, ok := w.ResponseWriter.(http.Hijacker) + if !ok { + return nil, nil, http.ErrNotSupported + } + conn, buf, err := hijacker.Hijack() + if err != nil { + return nil, nil, err + } + tc := &trackedConn{Conn: conn, tracker: w.tracker} + w.tracker.conns.Store(tc, struct{}{}) + return tc, buf, nil +} + +func (w *trackingWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +func (w *trackingWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} diff --git a/proxy/internal/conntrack/hijacked.go b/proxy/internal/conntrack/hijacked.go new file mode 100644 index 000000000..d76cebc08 --- /dev/null +++ b/proxy/internal/conntrack/hijacked.go @@ -0,0 +1,41 @@ +package conntrack + +import ( + "net" + "net/http" + "sync" +) + +// HijackTracker tracks connections that have been hijacked (e.g. WebSocket +// upgrades). http.Server.Shutdown does not close hijacked connections, so +// they must be tracked and closed explicitly during graceful shutdown. +// +// Use Middleware as the outermost HTTP middleware to ensure hijacked +// connections are tracked and automatically deregistered when closed. +type HijackTracker struct { + conns sync.Map // net.Conn → struct{} +} + +// Middleware returns an HTTP middleware that wraps the ResponseWriter so that +// hijacked connections are tracked and automatically deregistered from the +// tracker when closed. This should be the outermost middleware in the chain. +func (t *HijackTracker) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r) + }) +} + +// CloseAll closes all tracked hijacked connections and returns the number +// of connections that were closed. +func (t *HijackTracker) CloseAll() int { + var count int + t.conns.Range(func(key, _ any) bool { + if conn, ok := key.(net.Conn); ok { + _ = conn.Close() + count++ + } + t.conns.Delete(key) + return true + }) + return count +} diff --git a/proxy/internal/metrics/metrics.go b/proxy/internal/metrics/metrics.go index 951ce73dd..954020f77 100644 --- a/proxy/internal/metrics/metrics.go +++ b/proxy/internal/metrics/metrics.go @@ -5,9 +5,11 @@ import ( "strconv" "time" - "github.com/netbirdio/netbird/proxy/internal/proxy" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + + "github.com/netbirdio/netbird/proxy/internal/proxy" + "github.com/netbirdio/netbird/proxy/internal/responsewriter" ) type Metrics struct { @@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics { } type responseInterceptor struct { - http.ResponseWriter + *responsewriter.PassthroughWriter status int size int } func (w *responseInterceptor) WriteHeader(status int) { w.status = status - w.ResponseWriter.WriteHeader(status) + w.PassthroughWriter.WriteHeader(status) } func (w *responseInterceptor) Write(b []byte) (int, error) { - size, err := w.ResponseWriter.Write(b) + size, err := w.PassthroughWriter.Write(b) w.size += size return size, err } @@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler { m.requestsTotal.Inc() m.activeRequests.Inc() - interceptor := &responseInterceptor{ResponseWriter: w} + interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)} start := time.Now() next.ServeHTTP(interceptor, r) diff --git a/proxy/internal/responsewriter/responsewriter.go b/proxy/internal/responsewriter/responsewriter.go new file mode 100644 index 000000000..b8fc95f2d --- /dev/null +++ b/proxy/internal/responsewriter/responsewriter.go @@ -0,0 +1,53 @@ +package responsewriter + +import ( + "bufio" + "net" + "net/http" +) + +// PassthroughWriter wraps an http.ResponseWriter and preserves optional +// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying +// ResponseWriter if it supports them. +// +// This is the standard pattern for Go middleware that needs to wrap ResponseWriter +// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher), +// and HTTP/2 server push. +type PassthroughWriter struct { + http.ResponseWriter +} + +// New creates a new wrapper around the given ResponseWriter. +func New(w http.ResponseWriter) *PassthroughWriter { + return &PassthroughWriter{ResponseWriter: w} +} + +// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it. +// This is required for WebSocket connections and other protocol upgrades. +func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { + if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok { + return hijacker.Hijack() + } + return nil, nil, http.ErrNotSupported +} + +// Flush implements http.Flusher interface if the underlying ResponseWriter supports it. +func (w *PassthroughWriter) Flush() { + if flusher, ok := w.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} + +// Push implements http.Pusher interface if the underlying ResponseWriter supports it. +func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error { + if pusher, ok := w.ResponseWriter.(http.Pusher); ok { + return pusher.Push(target, opts) + } + return http.ErrNotSupported +} + +// Unwrap returns the underlying ResponseWriter. +// This is required for http.ResponseController (Go 1.20+) to work correctly. +func (w *PassthroughWriter) Unwrap() http.ResponseWriter { + return w.ResponseWriter +} diff --git a/proxy/server.go b/proxy/server.go index b08837679..52b4972ec 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -37,6 +37,7 @@ import ( "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" "github.com/netbirdio/netbird/proxy/internal/certwatch" + "github.com/netbirdio/netbird/proxy/internal/conntrack" "github.com/netbirdio/netbird/proxy/internal/debug" proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc" "github.com/netbirdio/netbird/proxy/internal/health" @@ -64,6 +65,11 @@ type Server struct { healthChecker *health.Checker meter *metrics.Metrics + // hijackTracker tracks hijacked connections (e.g. WebSocket upgrades) + // so they can be closed during graceful shutdown, since http.Server.Shutdown + // does not handle them. + hijackTracker conntrack.HijackTracker + // Mostly used for debugging on management. startTime time.Time @@ -185,10 +191,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { return err } + // Build the handler chain from inside out. + handler := http.Handler(s.proxy) + handler = s.auth.Protect(handler) + handler = web.AssetHandler(handler) + handler = accessLog.Middleware(handler) + handler = s.meter.Middleware(handler) + handler = s.hijackTracker.Middleware(handler) + // Start the reverse proxy HTTPS server. s.https = &http.Server{ Addr: addr, - Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))), + Handler: handler, TLSConfig: tlsConfig, ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS), } @@ -457,7 +471,12 @@ func (s *Server) gracefulShutdown() { s.Logger.Warnf("https server drain: %v", err) } - // Step 4: Stop all remaining background services. + // Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle. + if n := s.hijackTracker.CloseAll(); n > 0 { + s.Logger.Infof("closed %d hijacked connection(s)", n) + } + + // Step 5: Stop all remaining background services. s.shutdownServices() s.Logger.Info("graceful shutdown complete") }