Files
netbird/proxy/internal/proxy/reverseproxy.go
Viktor Liu 07e59b2708 Add reverse proxy header security and forwarding
- Rewrite Host header to backend target (configurable via pass_host_header per mapping)
- Strip and set X-Forwarded-For/X-Real-IP from direct connection (trust boundary)
- Set X-Forwarded-Host and X-Forwarded-Proto headers
- Strip nb_session cookie and session_token query param before forwarding
- Add --forwarded-proto flag (auto/http/https) for proto detection
- Fix OIDC redirect hardcoded https scheme
- Add pass_host_header to proto, API, and management model
2026-02-08 15:00:35 +08:00

229 lines
8.3 KiB
Go

package proxy
import (
	"context"
	"errors"
	"net"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strconv"
	"strings"
	"sync"

	log "github.com/sirupsen/logrus"

	"github.com/netbirdio/netbird/proxy/auth"
	"github.com/netbirdio/netbird/proxy/internal/roundtrip"
	"github.com/netbirdio/netbird/proxy/web"
)
// ReverseProxy is the NetBird proxy's HTTP handler. For each request it looks
// up the matching mapping and forwards the request to that mapping's backend
// through a per-request httputil.ReverseProxy.
type ReverseProxy struct {
	// transport carries outbound requests to the backends.
	transport http.RoundTripper
	// forwardedProto selects the X-Forwarded-Proto value sent to backends:
	// "auto" derives it from the inbound TLS state, while "http" or "https"
	// force that value.
	forwardedProto string

	// NOTE(review): mappingsMux presumably guards mappings; the accessors that
	// lock it (AddMapping/RemoveMapping and the lookup) are not shown here — confirm.
	mappingsMux sync.RWMutex
	mappings    map[string]Mapping

	// logger is never nil after NewReverseProxy (it defaults to the logrus
	// standard logger).
	logger *log.Logger
}
// NewReverseProxy creates a NetBird ReverseProxy with an empty mapping table.
//
// Requests are routed dynamically: each incoming request is matched against
// the internal mappings (managed with AddMapping and RemoveMapping) and then
// forwarded by an httputil.ReverseProxy over transport. forwardedProto
// controls the X-Forwarded-Proto value ("auto", "http" or "https"). A nil
// logger is replaced with the logrus standard logger.
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, logger *log.Logger) *ReverseProxy {
	rp := &ReverseProxy{
		transport:      transport,
		forwardedProto: forwardedProto,
		mappings:       make(map[string]Mapping),
		logger:         logger,
	}
	if rp.logger == nil {
		rp.logger = log.StandardLogger()
	}
	return rp
}
// ServeHTTP implements http.Handler.
//
// It resolves the mapping for r and renders a 404 "Service Not Found" page
// when none matches. Otherwise it:
//   - stores the mapping's service and account IDs in the request context, for
//     middleware and for the roundtrip transport;
//   - copies those IDs into any CapturedData found in the context;
//   - forwards the request through an httputil.ReverseProxy whose Rewrite hook
//     targets the mapped backend.
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	target, ok := p.findTargetForRequest(r)
	if !ok {
		web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
			"The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.",
			getRequestID(r), web.ErrorStatus{Proxy: true, Peer: false, Destination: false})
		return
	}

	// Service and account IDs for later retrieval by handlers and middleware,
	// plus the account ID under the key the roundtrip transport reads.
	ctx := withAccountId(withServiceId(r.Context(), target.serviceID), target.accountID)
	ctx = roundtrip.WithAccountID(ctx, target.accountID)

	// Data has to flow UP the middleware chain: outer middleware puts a mutable
	// CapturedData pointer in the context, and filling it in here lets that
	// middleware read the IDs after this handler completes.
	if captured := CapturedDataFromContext(ctx); captured != nil {
		captured.SetServiceId(target.serviceID)
		captured.SetAccountId(target.accountID)
	}

	forwarder := &httputil.ReverseProxy{
		Rewrite:      p.rewriteFunc(target.url, target.passHostHeader),
		Transport:    p.transport,
		ErrorHandler: proxyErrorHandler,
	}
	forwarder.ServeHTTP(w, r.WithContext(ctx))
}
// rewriteFunc builds the httputil.ReverseProxy Rewrite hook for one mapping.
//
// The hook does the following:
//   - Routes the outbound request to target.
//   - Sets the outbound Host. It uses the client's original Host when
//     passHostHeader is true and target.Host otherwise.
//   - Overwrites the forwarding headers. X-Forwarded-For and X-Real-IP come
//     from the connection's remote address. X-Forwarded-Host and the
//     X-Forwarded-Port fallback logic come from the inbound Host header.
//     X-Forwarded-Proto comes from auth.ResolveProto.
//   - Removes the proxy's own credentials, i.e. the session cookie and the
//     session_token query parameter, before the request reaches the backend.
func (p *ReverseProxy) rewriteFunc(target *url.URL, passHostHeader bool) func(r *httputil.ProxyRequest) {
	return func(pr *httputil.ProxyRequest) {
		pr.SetURL(target)

		outboundHost := target.Host
		if passHostHeader {
			outboundHost = pr.In.Host
		}
		pr.Out.Host = outboundHost

		ip := extractClientIP(pr.In.RemoteAddr)
		scheme := auth.ResolveProto(p.forwardedProto, pr.In.TLS)

		// This proxy is the trust boundary: Set (never Add) so any forwarding
		// headers supplied by the client are replaced, not extended.
		hdr := pr.Out.Header
		hdr.Set("X-Forwarded-For", ip)
		hdr.Set("X-Real-IP", ip)
		hdr.Set("X-Forwarded-Host", pr.In.Host)
		hdr.Set("X-Forwarded-Proto", scheme)
		hdr.Set("X-Forwarded-Port", extractForwardedPort(pr.In.Host, scheme))

		// Never leak proxy authentication material to the backend.
		stripSessionCookie(pr)
		stripSessionTokenQuery(pr)
	}
}
// stripSessionCookie removes the proxy's session cookie (auth.SessionCookieName)
// from the outgoing request and forwards every other cookie byte-for-byte.
//
// The Cookie header is filtered textually instead of via Request.Cookies and
// AddCookie, because that parse-and-rebuild round trip loses data:
//   - Cookies silently drops pairs that net/http considers malformed, e.g.
//     values containing double quotes, backslashes or non-ASCII bytes.
//   - AddCookie re-quotes values that contain spaces or commas.
//
// Either would remove or alter cookies that belong to the backend application.
// The remaining cookies are joined into a single Cookie header. When none are
// left, the header is omitted.
func stripSessionCookie(r *httputil.ProxyRequest) {
	lines := r.Out.Header.Values("Cookie")
	if len(lines) == 0 {
		return
	}

	kept := make([]string, 0, len(lines))
	// Cookies may arrive split across several Cookie header lines (as HTTP/2
	// clients do), each holding one or more "name=value" pairs.
	for _, line := range lines {
		for _, pair := range strings.Split(line, ";") {
			pair = strings.TrimSpace(pair)
			if pair == "" {
				continue
			}
			name, _, _ := strings.Cut(pair, "=")
			if strings.TrimSpace(name) == auth.SessionCookieName {
				continue
			}
			kept = append(kept, pair)
		}
	}

	r.Out.Header.Del("Cookie")
	if len(kept) > 0 {
		r.Out.Header.Set("Cookie", strings.Join(kept, "; "))
	}
}
// stripSessionTokenQuery removes every OIDC session_token query parameter from
// the outgoing URL to prevent credential leakage to backends.
//
// The raw query is filtered pair by pair instead of being parsed with
// URL.Query and re-encoded. That round trip would change what the backend
// receives:
//   - Values.Encode sorts parameters by key and normalizes their escaping.
//   - Query silently discards pairs it cannot parse, such as ones that
//     contain ';'.
//
// All other parameters keep their original order and encoding. The query is
// rewritten only when a session_token parameter was actually present.
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
	rawQuery := r.Out.URL.RawQuery
	if rawQuery == "" {
		return
	}

	pairs := strings.Split(rawQuery, "&")
	kept := make([]string, 0, len(pairs))
	for _, pair := range pairs {
		key, _, _ := strings.Cut(pair, "=")
		// Compare the decoded key so escaped spellings such as
		// "session%5Ftoken" are removed as well.
		if name, err := url.QueryUnescape(key); err == nil && name == "session_token" {
			continue
		}
		kept = append(kept, pair)
	}

	if len(kept) != len(pairs) {
		r.Out.URL.RawQuery = strings.Join(kept, "&")
	}
}
// extractClientIP returns the host part of an http.Request.RemoteAddr value,
// which is normally in "host:port" form ("[v6addr]:port" for IPv6, with the
// brackets removed in the result). When remoteAddr cannot be split, it is
// returned unchanged.
func extractClientIP(remoteAddr string) string {
	if host, _, err := net.SplitHostPort(remoteAddr); err == nil {
		return host
	}
	return remoteAddr
}
// extractForwardedPort returns the value for the X-Forwarded-Port header.
//
// host is an inbound Host header value such as "example.com:8443" or
// "[::1]:8080". Its explicit port is used when that port is a valid TCP port
// number (1-65535), normalized to canonical decimal form. Otherwise the result
// is the default port for resolvedProto: "443" for "https", "80" for anything
// else.
func extractForwardedPort(host, resolvedProto string) string {
	if _, port, err := net.SplitHostPort(host); err == nil {
		// The Host header is client-controlled. Only a well-formed port number
		// is forwarded, never arbitrary text.
		if n, perr := strconv.ParseUint(port, 10, 16); perr == nil && n > 0 {
			return strconv.FormatUint(n, 10)
		}
	}
	if resolvedProto == "https" {
		return "443"
	}
	return "80"
}
// proxyErrorHandler is the httputil.ReverseProxy ErrorHandler. Instead of a
// bare error response, it renders a user-facing error page. The page's title,
// message, status code and component status come from classifyProxyError.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
	title, message, code, status := classifyProxyError(err)
	web.ServeErrorPage(w, r, code, title, message, getRequestID(r), status)
}
// getRequestID returns the request ID recorded in the CapturedData stored in
// r's context. It returns "" when the context carries no CapturedData.
func getRequestID(r *http.Request) string {
	captured := CapturedDataFromContext(r.Context())
	if captured == nil {
		return ""
	}
	return captured.GetRequestID()
}
// classifyProxyError maps an error returned while proxying to the contents of
// the error page: its title, message, HTTP status code and the component
// status (Proxy/Peer/Destination) shown to the user.
//
// Checks run in priority order:
//  1. context deadline
//  2. context cancellation
//  3. missing account ID
//  4. substring matches on the error text
//
// Anything unrecognized yields a generic 502 "Connection Error". err must be
// non-nil.
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
	text := err.Error()
	mentions := func(fragments ...string) bool {
		for _, fragment := range fragments {
			if strings.Contains(text, fragment) {
				return true
			}
		}
		return false
	}

	if errors.Is(err, context.DeadlineExceeded) {
		return "Request Timeout",
			"The request timed out while trying to reach the service. Please refresh the page and try again.",
			http.StatusGatewayTimeout,
			web.ErrorStatus{Proxy: true, Peer: true, Destination: false}
	}
	if errors.Is(err, context.Canceled) {
		return "Request Canceled",
			"The request was canceled before it could be completed. Please refresh the page and try again.",
			http.StatusBadGateway,
			web.ErrorStatus{Proxy: true, Peer: true, Destination: false}
	}
	if errors.Is(err, roundtrip.ErrNoAccountID) {
		return "Configuration Error",
			"The request could not be processed due to a configuration issue. Please refresh the page and try again.",
			http.StatusInternalServerError,
			web.ErrorStatus{Proxy: false, Peer: false, Destination: false}
	}

	// The proxy peer (embedded client) is not connected.
	if mentions("no peer connection found", "start netbird client", "engine not started", "get net:") {
		return "Proxy Not Connected",
			"The proxy is not connected to the NetBird network. Please try again later or contact your administrator.",
			http.StatusBadGateway,
			web.ErrorStatus{Proxy: false, Peer: false, Destination: false}
	}

	// The routing peer is connected, but the destination service refused the connection.
	if mentions("connection refused") {
		return "Service Unavailable",
			"The connection to the service was refused. Please verify that the service is running and try again.",
			http.StatusBadGateway,
			web.ErrorStatus{Proxy: true, Peer: true, Destination: false}
	}

	// The peer itself is not reachable.
	if mentions("no route to host", "network is unreachable", "i/o timeout") {
		return "Peer Not Connected",
			"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
			http.StatusBadGateway,
			web.ErrorStatus{Proxy: true, Peer: false, Destination: false}
	}

	// Unrecognized error: show a generic message. Nothing is logged here.
	return "Connection Error",
		"An unexpected error occurred while connecting to the service. Please try again later.",
		http.StatusBadGateway,
		web.ErrorStatus{Proxy: true, Peer: false, Destination: false}
}