mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[proxy] Support WebSocket (#5312)
* Fix WebSocket support by implementing Hijacker interface Add responsewriter.PassthroughWriter to preserve optional HTTP interfaces (Hijacker, Flusher, Pusher) when wrapping http.ResponseWriter in middleware. Without this delegation: - WebSocket connections fail (can't hijack the connection) - Streaming breaks (can't flush buffers) - HTTP/2 push doesn't work * Add HijackTracker to manage hijacked connections during graceful shutdown * Refactor HijackTracker to use middleware for tracking hijacked connections * Refactor server handler chain setup for improved readability and maintainability
This commit is contained in:
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
"github.com/netbirdio/netbird/proxy/web"
|
"github.com/netbirdio/netbird/proxy/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,8 +28,8 @@ func (l *Logger) Middleware(next http.Handler) http.Handler {
|
|||||||
|
|
||||||
// Use a response writer wrapper so we can access the status code later.
|
// Use a response writer wrapper so we can access the status code later.
|
||||||
sw := &statusWriter{
|
sw := &statusWriter{
|
||||||
w: w,
|
PassthroughWriter: responsewriter.New(w),
|
||||||
status: http.StatusOK,
|
status: http.StatusOK,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve the source IP using trusted proxy configuration before passing
|
// Resolve the source IP using trusted proxy configuration before passing
|
||||||
|
|||||||
@@ -1,26 +1,18 @@
|
|||||||
package accesslog
|
package accesslog
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
// statusWriter is a simple wrapper around an http.ResponseWriter
|
// statusWriter captures the HTTP status code from WriteHeader calls.
|
||||||
// that captures the setting of the status code via the WriteHeader
|
// It embeds responsewriter.PassthroughWriter which handles all the optional
|
||||||
// function and stores it so that it can be retrieved later.
|
// interfaces (Hijacker, Flusher, Pusher) automatically.
|
||||||
type statusWriter struct {
|
type statusWriter struct {
|
||||||
w http.ResponseWriter
|
*responsewriter.PassthroughWriter
|
||||||
status int
|
status int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *statusWriter) Header() http.Header {
|
|
||||||
return w.w.Header()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *statusWriter) Write(data []byte) (int, error) {
|
|
||||||
return w.w.Write(data)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *statusWriter) WriteHeader(status int) {
|
func (w *statusWriter) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
w.w.WriteHeader(status)
|
w.PassthroughWriter.WriteHeader(status)
|
||||||
}
|
}
|
||||||
|
|||||||
49
proxy/internal/conntrack/conn.go
Normal file
49
proxy/internal/conntrack/conn.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// trackedConn wraps a net.Conn and removes itself from the tracker on Close.
|
||||||
|
type trackedConn struct {
|
||||||
|
net.Conn
|
||||||
|
tracker *HijackTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *trackedConn) Close() error {
|
||||||
|
c.tracker.conns.Delete(c)
|
||||||
|
return c.Conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// trackingWriter wraps an http.ResponseWriter and intercepts Hijack calls
|
||||||
|
// to replace the raw connection with a trackedConn that auto-deregisters.
|
||||||
|
type trackingWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
tracker *HijackTracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
hijacker, ok := w.ResponseWriter.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
conn, buf, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
tc := &trackedConn{Conn: conn, tracker: w.tracker}
|
||||||
|
w.tracker.conns.Store(tc, struct{}{})
|
||||||
|
return tc, buf, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *trackingWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
41
proxy/internal/conntrack/hijacked.go
Normal file
41
proxy/internal/conntrack/hijacked.go
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HijackTracker tracks connections that have been hijacked (e.g. WebSocket
|
||||||
|
// upgrades). http.Server.Shutdown does not close hijacked connections, so
|
||||||
|
// they must be tracked and closed explicitly during graceful shutdown.
|
||||||
|
//
|
||||||
|
// Use Middleware as the outermost HTTP middleware to ensure hijacked
|
||||||
|
// connections are tracked and automatically deregistered when closed.
|
||||||
|
type HijackTracker struct {
|
||||||
|
conns sync.Map // net.Conn → struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware returns an HTTP middleware that wraps the ResponseWriter so that
|
||||||
|
// hijacked connections are tracked and automatically deregistered from the
|
||||||
|
// tracker when closed. This should be the outermost middleware in the chain.
|
||||||
|
func (t *HijackTracker) Middleware(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next.ServeHTTP(&trackingWriter{ResponseWriter: w, tracker: t}, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseAll closes all tracked hijacked connections and returns the number
|
||||||
|
// of connections that were closed.
|
||||||
|
func (t *HijackTracker) CloseAll() int {
|
||||||
|
var count int
|
||||||
|
t.conns.Range(func(key, _ any) bool {
|
||||||
|
if conn, ok := key.(net.Conn); ok {
|
||||||
|
_ = conn.Close()
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
t.conns.Delete(key)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return count
|
||||||
|
}
|
||||||
@@ -5,9 +5,11 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
"github.com/prometheus/client_golang/prometheus/promauto"
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/responsewriter"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Metrics struct {
|
type Metrics struct {
|
||||||
@@ -60,18 +62,18 @@ func New(reg prometheus.Registerer) *Metrics {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type responseInterceptor struct {
|
type responseInterceptor struct {
|
||||||
http.ResponseWriter
|
*responsewriter.PassthroughWriter
|
||||||
status int
|
status int
|
||||||
size int
|
size int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseInterceptor) WriteHeader(status int) {
|
func (w *responseInterceptor) WriteHeader(status int) {
|
||||||
w.status = status
|
w.status = status
|
||||||
w.ResponseWriter.WriteHeader(status)
|
w.PassthroughWriter.WriteHeader(status)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
func (w *responseInterceptor) Write(b []byte) (int, error) {
|
||||||
size, err := w.ResponseWriter.Write(b)
|
size, err := w.PassthroughWriter.Write(b)
|
||||||
w.size += size
|
w.size += size
|
||||||
return size, err
|
return size, err
|
||||||
}
|
}
|
||||||
@@ -81,7 +83,7 @@ func (m *Metrics) Middleware(next http.Handler) http.Handler {
|
|||||||
m.requestsTotal.Inc()
|
m.requestsTotal.Inc()
|
||||||
m.activeRequests.Inc()
|
m.activeRequests.Inc()
|
||||||
|
|
||||||
interceptor := &responseInterceptor{ResponseWriter: w}
|
interceptor := &responseInterceptor{PassthroughWriter: responsewriter.New(w)}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
next.ServeHTTP(interceptor, r)
|
next.ServeHTTP(interceptor, r)
|
||||||
|
|||||||
53
proxy/internal/responsewriter/responsewriter.go
Normal file
53
proxy/internal/responsewriter/responsewriter.go
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
package responsewriter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PassthroughWriter wraps an http.ResponseWriter and preserves optional
|
||||||
|
// interfaces like Hijacker, Flusher, and Pusher by delegating to the underlying
|
||||||
|
// ResponseWriter if it supports them.
|
||||||
|
//
|
||||||
|
// This is the standard pattern for Go middleware that needs to wrap ResponseWriter
|
||||||
|
// while maintaining support for protocol upgrades (WebSocket), streaming (Flusher),
|
||||||
|
// and HTTP/2 server push.
|
||||||
|
type PassthroughWriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new wrapper around the given ResponseWriter.
|
||||||
|
func New(w http.ResponseWriter) *PassthroughWriter {
|
||||||
|
return &PassthroughWriter{ResponseWriter: w}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker interface if the underlying ResponseWriter supports it.
|
||||||
|
// This is required for WebSocket connections and other protocol upgrades.
|
||||||
|
func (w *PassthroughWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := w.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return nil, nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implements http.Flusher interface if the underlying ResponseWriter supports it.
|
||||||
|
func (w *PassthroughWriter) Flush() {
|
||||||
|
if flusher, ok := w.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push implements http.Pusher interface if the underlying ResponseWriter supports it.
|
||||||
|
func (w *PassthroughWriter) Push(target string, opts *http.PushOptions) error {
|
||||||
|
if pusher, ok := w.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return pusher.Push(target, opts)
|
||||||
|
}
|
||||||
|
return http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unwrap returns the underlying ResponseWriter.
|
||||||
|
// This is required for http.ResponseController (Go 1.20+) to work correctly.
|
||||||
|
func (w *PassthroughWriter) Unwrap() http.ResponseWriter {
|
||||||
|
return w.ResponseWriter
|
||||||
|
}
|
||||||
@@ -37,6 +37,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/proxy/internal/acme"
|
"github.com/netbirdio/netbird/proxy/internal/acme"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/auth"
|
"github.com/netbirdio/netbird/proxy/internal/auth"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
"github.com/netbirdio/netbird/proxy/internal/certwatch"
|
||||||
|
"github.com/netbirdio/netbird/proxy/internal/conntrack"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/debug"
|
"github.com/netbirdio/netbird/proxy/internal/debug"
|
||||||
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
proxygrpc "github.com/netbirdio/netbird/proxy/internal/grpc"
|
||||||
"github.com/netbirdio/netbird/proxy/internal/health"
|
"github.com/netbirdio/netbird/proxy/internal/health"
|
||||||
@@ -64,6 +65,11 @@ type Server struct {
|
|||||||
healthChecker *health.Checker
|
healthChecker *health.Checker
|
||||||
meter *metrics.Metrics
|
meter *metrics.Metrics
|
||||||
|
|
||||||
|
// hijackTracker tracks hijacked connections (e.g. WebSocket upgrades)
|
||||||
|
// so they can be closed during graceful shutdown, since http.Server.Shutdown
|
||||||
|
// does not handle them.
|
||||||
|
hijackTracker conntrack.HijackTracker
|
||||||
|
|
||||||
// Mostly used for debugging on management.
|
// Mostly used for debugging on management.
|
||||||
startTime time.Time
|
startTime time.Time
|
||||||
|
|
||||||
@@ -185,10 +191,18 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build the handler chain from inside out.
|
||||||
|
handler := http.Handler(s.proxy)
|
||||||
|
handler = s.auth.Protect(handler)
|
||||||
|
handler = web.AssetHandler(handler)
|
||||||
|
handler = accessLog.Middleware(handler)
|
||||||
|
handler = s.meter.Middleware(handler)
|
||||||
|
handler = s.hijackTracker.Middleware(handler)
|
||||||
|
|
||||||
// Start the reverse proxy HTTPS server.
|
// Start the reverse proxy HTTPS server.
|
||||||
s.https = &http.Server{
|
s.https = &http.Server{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
Handler: s.meter.Middleware(accessLog.Middleware(web.AssetHandler(s.auth.Protect(s.proxy)))),
|
Handler: handler,
|
||||||
TLSConfig: tlsConfig,
|
TLSConfig: tlsConfig,
|
||||||
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
ErrorLog: newHTTPServerLogger(s.Logger, logtagValueHTTPS),
|
||||||
}
|
}
|
||||||
@@ -457,7 +471,12 @@ func (s *Server) gracefulShutdown() {
|
|||||||
s.Logger.Warnf("https server drain: %v", err)
|
s.Logger.Warnf("https server drain: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4: Stop all remaining background services.
|
// Step 4: Close hijacked connections (WebSocket) that Shutdown does not handle.
|
||||||
|
if n := s.hijackTracker.CloseAll(); n > 0 {
|
||||||
|
s.Logger.Infof("closed %d hijacked connection(s)", n)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Stop all remaining background services.
|
||||||
s.shutdownServices()
|
s.shutdownServices()
|
||||||
s.Logger.Info("graceful shutdown complete")
|
s.Logger.Info("graceful shutdown complete")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user