Add reverse proxy header security and forwarding

- Rewrite Host header to backend target (configurable via pass_host_header per mapping)
- Strip and set X-Forwarded-For/X-Real-IP from direct connection (trust boundary)
- Set X-Forwarded-Host and X-Forwarded-Proto headers
- Strip nb_session cookie and session_token query param before forwarding
- Add --forwarded-proto flag (auto/http/https) for proto detection
- Fix OIDC redirect hardcoded https scheme
- Add pass_host_header to proto, API, and management model
This commit is contained in:
Viktor Liu
2026-02-08 14:16:52 +08:00
parent 0a3a9f977d
commit 07e59b2708
13 changed files with 700 additions and 228 deletions

View File

@@ -3,22 +3,28 @@ package proxy
import (
"context"
"errors"
"net"
"net/http"
"net/http/httputil"
"net/url"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/web"
)
type ReverseProxy struct {
transport http.RoundTripper
mappingsMux sync.RWMutex
mappings map[string]Mapping
logger *log.Logger
transport http.RoundTripper
// forwardedProto overrides the X-Forwarded-Proto header value.
// Valid values: "auto" (detect from TLS), "http", "https".
forwardedProto string
mappingsMux sync.RWMutex
mappings map[string]Mapping
logger *log.Logger
}
// NewReverseProxy configures a new NetBird ReverseProxy.
@@ -27,19 +33,20 @@ type ReverseProxy struct {
// between requested URLs and targets.
// The internal mappings can be modified using the AddMapping
// and RemoveMapping functions.
func NewReverseProxy(transport http.RoundTripper, logger *log.Logger) *ReverseProxy {
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, logger *log.Logger) *ReverseProxy {
if logger == nil {
logger = log.StandardLogger()
}
return &ReverseProxy{
transport: transport,
mappings: make(map[string]Mapping),
logger: logger,
transport: transport,
forwardedProto: forwardedProto,
mappings: make(map[string]Mapping),
logger: logger,
}
}
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
target, serviceId, accountID, exists := p.findTargetForRequest(r)
result, exists := p.findTargetForRequest(r)
if !exists {
requestID := getRequestID(r)
web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
@@ -49,25 +56,101 @@ func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), serviceId)
ctx := withServiceId(r.Context(), result.serviceID)
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, accountID)
ctx = withAccountId(ctx, result.accountID)
// Set the accountId in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, accountID)
ctx = roundtrip.WithAccountID(ctx, result.accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
// pointer in the context, and mutate the struct here so outer middleware can read it.
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
capturedData.SetServiceId(serviceId)
capturedData.SetAccountId(accountID)
capturedData.SetServiceId(result.serviceID)
capturedData.SetAccountId(result.accountID)
}
// Set up a reverse proxy using the transport and then use it to serve the request.
proxy := httputil.NewSingleHostReverseProxy(target)
proxy.Transport = p.transport
proxy.ErrorHandler = proxyErrorHandler
proxy.ServeHTTP(w, r.WithContext(ctx))
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(result.url, result.passHostHeader),
Transport: p.transport,
ErrorHandler: proxyErrorHandler,
}
rp.ServeHTTP(w, r.WithContext(ctx))
}
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
// inbound requests to target the backend service while setting security-relevant
// forwarding headers and stripping proxy authentication credentials.
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
func (p *ReverseProxy) rewriteFunc(target *url.URL, passHostHeader bool) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
r.SetURL(target)
if passHostHeader {
r.Out.Host = r.In.Host
} else {
r.Out.Host = target.Host
}
clientIP := extractClientIP(r.In.RemoteAddr)
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
// Strip any incoming forwarding headers since this proxy is the trust
// boundary and set them fresh based on the direct connection.
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Real-IP", clientIP)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
stripSessionCookie(r)
stripSessionTokenQuery(r)
}
}
// stripSessionCookie removes the proxy's session cookie from the outgoing
// request while preserving all other cookies.
func stripSessionCookie(r *httputil.ProxyRequest) {
cookies := r.In.Cookies()
r.Out.Header.Del("Cookie")
for _, c := range cookies {
if c.Name != auth.SessionCookieName {
r.Out.AddCookie(c)
}
}
}
// stripSessionTokenQuery removes the OIDC session_token query parameter from
// the outgoing URL to prevent credential leakage to backends.
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
q := r.Out.URL.Query()
if q.Has("session_token") {
q.Del("session_token")
r.Out.URL.RawQuery = q.Encode()
}
}
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
// which is always in host:port format.
func extractClientIP(remoteAddr string) string {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return ip
}
// extractForwardedPort returns the port from the Host header if present,
// otherwise defaults to the standard port for the resolved protocol.
func extractForwardedPort(host, resolvedProto string) string {
_, port, err := net.SplitHostPort(host)
if err == nil && port != "" {
return port
}
if resolvedProto == "https" {
return "443"
}
return "80"
}
// proxyErrorHandler handles errors from the reverse proxy and serves

View File

@@ -0,0 +1,313 @@
package proxy
import (
"crypto/tls"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/auth"
)
func TestRewriteFunc_HostRewriting(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, true)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
assert.Equal(t, "public.example.com", pr.Out.Host,
"Host header should be the original client host")
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
"URL host (used for TLS/SNI) must still point to the backend")
})
}
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"should be set to the connecting client IP")
})
t.Run("strips spoofed X-Forwarded-For from client", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"spoofed XFF must be replaced, not appended to")
})
t.Run("strips spoofed X-Real-IP from client", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Real-IP", "10.0.0.1")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"spoofed X-Real-IP must be replaced")
})
}
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "myapp.example.com:8443", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
})
}
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})
rewrite(pr)
cookies := pr.Out.Cookies()
for _, c := range cookies {
assert.NotEqual(t, auth.SessionCookieName, c.Name,
"proxy session cookie must not be forwarded to backend")
}
})
t.Run("preserves other cookies", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"})
pr.In.AddCookie(&http.Cookie{Name: "app_session", Value: "app-value"})
pr.In.AddCookie(&http.Cookie{Name: "tracking", Value: "track-value"})
rewrite(pr)
cookies := pr.Out.Cookies()
cookieNames := make([]string, 0, len(cookies))
for _, c := range cookies {
cookieNames = append(cookieNames, c.Name)
}
assert.Contains(t, cookieNames, "app_session", "non-proxy cookies should be preserved")
assert.Contains(t, cookieNames, "tracking", "non-proxy cookies should be preserved")
assert.NotContains(t, cookieNames, auth.SessionCookieName, "proxy cookie must be stripped")
})
t.Run("handles request with no cookies", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("Cookie"))
})
}
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
rewrite(pr)
assert.Empty(t, pr.Out.URL.Query().Get("session_token"),
"OIDC session token must be stripped from backend request")
assert.Equal(t, "keep", pr.Out.URL.Query().Get("other"),
"other query parameters must be preserved")
})
t.Run("preserves query when no session_token present", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/api?foo=bar&baz=qux", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "bar", pr.Out.URL.Query().Get("foo"))
assert.Equal(t, "qux", pr.Out.URL.Query().Get("baz"))
})
}
func TestRewriteFunc_URLRewriting(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, false)
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "http", pr.Out.URL.Scheme)
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host)
assert.Equal(t, "/app/somepath", pr.Out.URL.Path,
"SetURL should join the target base path with the request path")
})
}
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"IPv6 with port", "[::1]:12345", "::1"},
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
{"IPv6 without brackets fallback", "::1", "::1"},
{"empty string fallback", "", ""},
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
})
}
}
func TestExtractForwardedPort(t *testing.T) {
tests := []struct {
name string
host string
resolvedProto string
expected string
}{
{"explicit port in host", "example.com:8443", "https", "8443"},
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
{"no port defaults to 443 for https", "example.com", "https", "443"},
{"no port defaults to 80 for http", "example.com", "http", "80"},
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
{"IPv6 host without port", "::1", "https", "443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
})
}
}
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
// Out is a shallow clone of In with headers copied.
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
t.Helper()
parsed, err := url.Parse(rawURL)
require.NoError(t, err)
in := httptest.NewRequest(http.MethodGet, rawURL, nil)
in.RemoteAddr = remoteAddr
in.Host = parsed.Host
out := in.Clone(in.Context())
out.Header = in.Header.Clone()
return &httputil.ProxyRequest{In: in, Out: out}
}

View File

@@ -11,22 +11,22 @@ import (
)
type Mapping struct {
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
PassHostHeader bool
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string, types.AccountID, bool) {
type targetResult struct {
url *url.URL
serviceID string
accountID types.AccountID
passHostHeader bool
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
p.mappingsMux.RLock()
if p.mappings == nil {
p.mappingsMux.RUnlock()
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
p.mappings = make(map[string]Mapping)
// There cannot be any loaded Mappings as we have only just initialized.
return nil, "", "", false
}
defer p.mappingsMux.RUnlock()
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
@@ -38,7 +38,7 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
p.logger.Debugf("looking for mapping for host: %s, path: %s", host, req.URL.Path)
m, exists := p.mappings[host]
if !exists {
return nil, "", "", false
return targetResult{}, false
}
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
@@ -52,18 +52,20 @@ func (p *ReverseProxy) findTargetForRequest(req *http.Request) (*url.URL, string
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) {
return m.Paths[path], m.ID, m.AccountID, true
return targetResult{
url: m.Paths[path],
serviceID: m.ID,
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
}, true
}
}
return nil, "", "", false
return targetResult{}, false
}
func (p *ReverseProxy) AddMapping(m Mapping) {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
if p.mappings == nil {
p.mappings = make(map[string]Mapping)
}
p.mappings[m.Host] = m
}