mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[proxy] Support WebSocket (#5312)
* Fix WebSocket support by implementing Hijacker interface Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces (Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware. Without this delegation: - WebSocket connections fail (can't hijack the connection) - Streaming breaks (can't flush buffers) - HTTP/2 push doesn't work * Add HijackTracker to manage hijacked connections during graceful shutdown * Refactor HijackTracker to use middleware for tracking hijacked connections * Refactor server handler chain setup for improved readability and maintainability
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
||||
|
||||
// Use a response writer wrapper so we can access the status code later.
|
||||
sw := &statusWriter{
|
||||
w: w,
|
||||
status: http.StatusOK,
|
||||
PassthroughWriter: responsewriter.New(w),
|
||||
status: http.StatusOK,
|
||||
}
|
||||
|
||||
// Resolve the source IP using trusted proxy configuration before passing
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
package accesslog
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
||||
// that captures the setting of the status code via the WriteHeader
|
||||
// function and stores it so that it can be retrieved later.
|
||||
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||
type statusWriter struct {
|
||||
w http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) Header() http.Header {
|
||||
return w.w.Header()
|
||||
}
|
||||
|
||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
||||
return w.w.Write(data)
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.w.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
49
proxy/internal/conntrack/conn.go
Normal file
49
proxy/internal/conntrack/conn.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||
type trackedConn struct {
|
||||
net.Conn
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (c *trackedConn) Close() error {
|
||||
c.tracker.conns.Delete(c)
|
||||
return c.Conn.Close()
|
||||
}
|
||||
|
||||
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||
type trackingWriter struct {
|
||||
http.ResponseWriter
|
||||
tracker *HijackTracker
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
conn, buf, err := hijacker.Hijack()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||
w.tracker.conns.Store(tc, struct{}{})
|
||||
return tc, buf, nil
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
41
proxy/internal/conntrack/hijacked.go
Normal file
41
proxy/internal/conntrack/hijacked.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package conntrack
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||
// they must be tracked and closed explicitly during graceful shutdown.
|
||||
//
|
||||
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||
// connections are tracked and automatically deregistered when closed.
|
||||
type HijackTracker struct {
|
||||
conns sync.Map // net.Conn → struct{}
|
||||
}
|
||||
|
||||
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||
// hijacked connections are tracked and automatically deregistered from the
|
||||
// tracker when closed. This should be the outermost middleware in the chain.
|
||||
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||
})
|
||||
}
|
||||
|
||||
// CloseAll closes all tracked hijacked connections and returns the number
|
||||
// of connections that were closed.
|
||||
func (t *HijackTracker) CloseAll() int {
|
||||
var count int
|
||||
t.conns.Range(func(key, _ any) bool {
|
||||
if conn, ok := key.(net.Conn); ok {
|
||||
_ = conn.Close()
|
||||
count++
|
||||
}
|
||||
t.conns.Delete(key)
|
||||
return true
|
||||
})
|
||||
return count
|
||||
}
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||
)
|
||||
|
||||
type Metrics struct {
|
||||
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
||||
}
|
||||
|
||||
type responseInterceptor struct {
|
||||
http.ResponseWriter
|
||||
*responsewriter.PassthroughWriter
|
||||
status int
|
||||
size int
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) WriteHeader(status int) {
|
||||
w.status = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
w.PassthroughWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||
size, err := w.ResponseWriter.Write(b)
|
||||
size, err := w.PassthroughWriter.Write(b)
|
||||
w.size += size
|
||||
return size, err
|
||||
}
|
||||
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
||||
m.requestsTotal.Inc()
|
||||
m.activeRequests.Inc()
|
||||
|
||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
||||
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||
|
||||
start := time.Now()
|
||||
next.ServeHTTP(interceptor, r)
|
||||
|
||||
53
proxy/internal/responsewriter/responsewriter.go
Normal file
53
proxy/internal/responsewriter/responsewriter.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package responsewriter
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"net"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||
// ResponseWriter if it supports them.
|
||||
//
|
||||
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||
// and HTTP/2 server push.
|
||||
type PassthroughWriter struct {
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
// New creates a new wrapper around the given ResponseWriter.
|
||||
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||
return &PassthroughWriter{ResponseWriter: w}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||
// This is required for WebSocket connections and other protocol upgrades.
|
||||
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return nil, nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Flush() {
|
||||
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher.Push(target, opts)
|
||||
}
|
||||
return http.ErrNotSupported
|
||||
}
|
||||
|
||||
// Unwrap returns the underlying ResponseWriter.
|
||||
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||
return w.ResponseWriter
|
||||
}
|
||||
@@ -37,6 +37,7 @@ import (
|
||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||
@@ -64,6 +65,11 @@ type Server struct {
|
||||
healthChecker *health.Checker
|
||||
meter *metrics.Metrics
|
||||
|
||||
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||
// does not handle them.
|
||||
hijackTracker conntrack.HijackTracker
|
||||
|
||||
// Mostly used for debugging on management.
|
||||
startTime time.Time
|
||||
|
||||
@@ -185,10 +191,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Build the handler chain from inside out.
|
||||
handler := http.Handler(s.proxy)
|
||||
handler = s.auth.Protect(handler)
|
||||
handler = web.AssetHandler(handler)
|
||||
handler = accessLog.Middleware(handler)
|
||||
handler = s.meter.Middleware(handler)
|
||||
handler = s.hijackTracker.Middleware(handler)
|
||||
|
||||
// Start the reverse proxy HTTPS server.
|
||||
s.https = &http.Server{
|
||||
Addr: addr,
|
||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
||||
Handler: handler,
|
||||
TLSConfig: tlsConfig,
|
||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||
}
|
||||
@@ -457,7 +471,12 @@ func (s *Server) gracefulShutdown() {
|
||||
s.Logger.Warnf("https server drain: %v", err)
|
||||
}
|
||||
|
||||
// Step 4: Stop all remaining background services.
|
||||
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||
}
|
||||
|
||||
// Step 5: Stop all remaining background services.
|
||||
s.shutdownServices()
|
||||
s.Logger.Info("graceful shutdown complete")
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user