[management, reverse proxy] Add reverse proxy feature (#5291)

* implement reverse proxy


---------

Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk>
Co-authored-by: mlsmaycon <mlsmaycon@gmail.com>
Co-authored-by: Eduard Gert <kontakt@eduardgert.de>
Co-authored-by: Viktor Liu <viktor@netbird.io>
Co-authored-by: Diego Noguês <diego.sure@gmail.com>
Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com>
Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com>
Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com>
Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
Pascal Fischer
2026-02-13 19:37:43 +01:00
committed by GitHub
parent edce11b34d
commit f53155562f
225 changed files with 35513 additions and 235 deletions

View File

@@ -0,0 +1,187 @@
package proxy
import (
"context"
"sync"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type requestContextKey string
const (
serviceIdKey requestContextKey = "serviceId"
accountIdKey requestContextKey = "accountId"
capturedDataKey requestContextKey = "capturedData"
)
// ResponseOrigin indicates where a response was generated.
type ResponseOrigin int
const (
// OriginBackend means the response came from the backend service.
OriginBackend ResponseOrigin = iota
// OriginNoRoute means the proxy had no matching host or path.
OriginNoRoute
// OriginProxyError means the proxy failed to reach the backend.
OriginProxyError
// OriginAuth means the proxy intercepted the request for authentication.
OriginAuth
)
func (o ResponseOrigin) String() string {
switch o {
case OriginNoRoute:
return "no_route"
case OriginProxyError:
return "proxy_error"
case OriginAuth:
return "auth"
default:
return "backend"
}
}
// CapturedData is a mutable struct that allows downstream handlers
// to pass data back up the middleware chain.
type CapturedData struct {
mu sync.RWMutex
RequestID string
ServiceId string
AccountId types.AccountID
Origin ResponseOrigin
ClientIP string
UserID string
AuthMethod string
}
// GetRequestID safely gets the request ID
func (c *CapturedData) GetRequestID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.RequestID
}
// SetServiceId safely sets the service ID
func (c *CapturedData) SetServiceId(serviceId string) {
c.mu.Lock()
defer c.mu.Unlock()
c.ServiceId = serviceId
}
// GetServiceId safely gets the service ID
func (c *CapturedData) GetServiceId() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ServiceId
}
// SetAccountId safely sets the account ID
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
c.mu.Lock()
defer c.mu.Unlock()
c.AccountId = accountId
}
// GetAccountId safely gets the account ID
func (c *CapturedData) GetAccountId() types.AccountID {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AccountId
}
// SetOrigin safely sets the response origin
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
c.mu.Lock()
defer c.mu.Unlock()
c.Origin = origin
}
// GetOrigin safely gets the response origin
func (c *CapturedData) GetOrigin() ResponseOrigin {
c.mu.RLock()
defer c.mu.RUnlock()
return c.Origin
}
// SetClientIP safely sets the resolved client IP.
func (c *CapturedData) SetClientIP(ip string) {
c.mu.Lock()
defer c.mu.Unlock()
c.ClientIP = ip
}
// GetClientIP safely gets the resolved client IP.
func (c *CapturedData) GetClientIP() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.ClientIP
}
// SetUserID safely sets the authenticated user ID.
func (c *CapturedData) SetUserID(userID string) {
c.mu.Lock()
defer c.mu.Unlock()
c.UserID = userID
}
// GetUserID safely gets the authenticated user ID.
func (c *CapturedData) GetUserID() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.UserID
}
// SetAuthMethod safely sets the authentication method used.
func (c *CapturedData) SetAuthMethod(method string) {
c.mu.Lock()
defer c.mu.Unlock()
c.AuthMethod = method
}
// GetAuthMethod safely gets the authentication method used.
func (c *CapturedData) GetAuthMethod() string {
c.mu.RLock()
defer c.mu.RUnlock()
return c.AuthMethod
}
// WithCapturedData adds a CapturedData struct to the context
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
return context.WithValue(ctx, capturedDataKey, data)
}
// CapturedDataFromContext retrieves the CapturedData from context
func CapturedDataFromContext(ctx context.Context) *CapturedData {
v := ctx.Value(capturedDataKey)
data, ok := v.(*CapturedData)
if !ok {
return nil
}
return data
}
func withServiceId(ctx context.Context, serviceId string) context.Context {
return context.WithValue(ctx, serviceIdKey, serviceId)
}
func ServiceIdFromContext(ctx context.Context) string {
v := ctx.Value(serviceIdKey)
serviceId, ok := v.(string)
if !ok {
return ""
}
return serviceId
}
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
return context.WithValue(ctx, accountIdKey, accountId)
}
func AccountIdFromContext(ctx context.Context) types.AccountID {
v := ctx.Value(accountIdKey)
accountId, ok := v.(types.AccountID)
if !ok {
return ""
}
return accountId
}

View File

@@ -0,0 +1,130 @@
package proxy_test
import (
"crypto/rand"
"fmt"
"math/big"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"github.com/netbirdio/netbird/proxy/internal/proxy"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type nopTransport struct{}
func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: http.StatusOK,
Body: http.NoBody,
}, nil
}
func BenchmarkServeHTTP(b *testing.B) {
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: map[string]*url.URL{
"/": {
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
})
req := httptest.NewRequest(http.MethodGet, "http://app.example.com", nil)
req.Host = "app.example.com"
req.RemoteAddr = "203.0.113.50:12345"
for b.Loop() {
rp.ServeHTTP(httptest.NewRecorder(), req)
}
}
func BenchmarkServeHTTPHostCount(b *testing.B) {
hostCounts := []int{1, 10, 100, 1_000, 10_000}
for _, hostCount := range hostCounts {
b.Run(fmt.Sprintf("hosts=%d", hostCount), func(b *testing.B) {
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
var target string
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(hostCount)))
if err != nil {
b.Fatal(err)
}
for i := range hostCount {
id := rand.Text()
host := fmt.Sprintf("%s.example.com", id)
if int64(i) == targetIndex.Int64() {
target = id
}
rp.AddMapping(proxy.Mapping{
ID: id,
AccountID: types.AccountID(rand.Text()),
Host: host,
Paths: map[string]*url.URL{
"/": {
Scheme: "http",
Host: "10.0.0.1:8080",
},
},
})
}
req := httptest.NewRequest(http.MethodGet, "http://"+target+"/", nil)
req.Host = target
req.RemoteAddr = "203.0.113.50:12345"
for b.Loop() {
rp.ServeHTTP(httptest.NewRecorder(), req)
}
})
}
}
func BenchmarkServeHTTPPathCount(b *testing.B) {
pathCounts := []int{1, 5, 10, 25, 50}
for _, pathCount := range pathCounts {
b.Run(fmt.Sprintf("paths=%d", pathCount), func(b *testing.B) {
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
var target string
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(pathCount)))
if err != nil {
b.Fatal(err)
}
paths := make(map[string]*url.URL, pathCount)
for i := range pathCount {
path := "/" + rand.Text()
if int64(i) == targetIndex.Int64() {
target = path
}
paths[path] = &url.URL{
Scheme: "http",
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
}
}
rp.AddMapping(proxy.Mapping{
ID: rand.Text(),
AccountID: types.AccountID(rand.Text()),
Host: "app.example.com",
Paths: paths,
})
req := httptest.NewRequest(http.MethodGet, "http://app.example.com"+target, nil)
req.Host = "app.example.com"
req.RemoteAddr = "203.0.113.50:12345"
for b.Loop() {
rp.ServeHTTP(httptest.NewRecorder(), req)
}
})
}
}

View File

@@ -0,0 +1,406 @@
package proxy
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/web"
)
type ReverseProxy struct {
transport http.RoundTripper
// forwardedProto overrides the X-Forwarded-Proto header value.
// Valid values: "auto" (detect from TLS), "http", "https".
forwardedProto string
// trustedProxies is a list of IP prefixes for trusted upstream proxies.
// When the direct connection comes from a trusted proxy, forwarding
// headers are preserved and appended to instead of being stripped.
trustedProxies []netip.Prefix
mappingsMux sync.RWMutex
mappings map[string]Mapping
logger *log.Logger
}
// NewReverseProxy configures a new NetBird ReverseProxy.
// This is a wrapper around an httputil.ReverseProxy set
// to dynamically route requests based on internal mapping
// between requested URLs and targets.
// The internal mappings can be modified using the AddMapping
// and RemoveMapping functions.
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy {
if logger == nil {
logger = log.StandardLogger()
}
return &ReverseProxy{
transport: transport,
forwardedProto: forwardedProto,
trustedProxies: trustedProxies,
mappings: make(map[string]Mapping),
logger: logger,
}
}
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
result, exists := p.findTargetForRequest(r)
if !exists {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginNoRoute)
}
requestID := getRequestID(r)
web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
"The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.",
requestID, web.ErrorStatus{Proxy: true, Destination: false})
return
}
// Set the serviceId in the context for later retrieval.
ctx := withServiceId(r.Context(), result.serviceID)
// Set the accountId in the context for later retrieval (for middleware).
ctx = withAccountId(ctx, result.accountID)
// Set the accountId in the context for the roundtripper to use.
ctx = roundtrip.WithAccountID(ctx, result.accountID)
// Also populate captured data if it exists (allows middleware to read after handler completes).
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
// pointer in the context, and mutate the struct here so outer middleware can read it.
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
capturedData.SetServiceId(result.serviceID)
capturedData.SetAccountId(result.accountID)
}
rp := &httputil.ReverseProxy{
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
Transport: p.transport,
ErrorHandler: proxyErrorHandler,
}
if result.rewriteRedirects {
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
}
rp.ServeHTTP(w, r.WithContext(ctx))
}
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
// inbound requests to target the backend service while setting security-relevant
// forwarding headers and stripping proxy authentication credentials.
// When passHostHeader is true, the original client Host header is preserved
// instead of being rewritten to the backend's address.
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
return func(r *httputil.ProxyRequest) {
// Strip the matched path prefix from the incoming request path before
// SetURL joins it with the target's base path, avoiding path duplication.
if matchedPath != "" && matchedPath != "/" {
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
if r.Out.URL.Path == "" {
r.Out.URL.Path = "/"
}
r.Out.URL.RawPath = ""
}
r.SetURL(target)
if passHostHeader {
r.Out.Host = r.In.Host
} else {
r.Out.Host = target.Host
}
clientIP := extractClientIP(r.In.RemoteAddr)
if IsTrustedProxy(clientIP, p.trustedProxies) {
p.setTrustedForwardingHeaders(r, clientIP)
} else {
p.setUntrustedForwardingHeaders(r, clientIP)
}
stripSessionCookie(r)
stripSessionTokenQuery(r)
}
}
// rewriteLocationFunc returns a ModifyResponse function that rewrites Location
// headers in backend responses when they point to the backend's address,
// replacing them with the public-facing host and scheme.
func (p *ReverseProxy) rewriteLocationFunc(target *url.URL, matchedPath string, inReq *http.Request) func(*http.Response) error {
publicHost := inReq.Host
publicScheme := auth.ResolveProto(p.forwardedProto, inReq.TLS)
return func(resp *http.Response) error {
location := resp.Header.Get("Location")
if location == "" {
return nil
}
locURL, err := url.Parse(location)
if err != nil {
return fmt.Errorf("parse Location header %q: %w", location, err)
}
// Only rewrite absolute URLs that point to the backend.
if locURL.Host == "" || !hostsEqual(locURL, target) {
return nil
}
locURL.Host = publicHost
locURL.Scheme = publicScheme
// Re-add the stripped path prefix so the client reaches the correct route.
// TrimRight prevents double slashes when matchedPath has a trailing slash.
if matchedPath != "" && matchedPath != "/" {
locURL.Path = strings.TrimRight(matchedPath, "/") + "/" + strings.TrimLeft(locURL.Path, "/")
}
resp.Header.Set("Location", locURL.String())
return nil
}
}
// hostsEqual compares two URL authorities, normalizing default ports per
// RFC 3986 Section 6.2.3 (https://443 == https, http://80 == http).
func hostsEqual(a, b *url.URL) bool {
return normalizeHost(a) == normalizeHost(b)
}
// normalizeHost strips the port from a URL's Host field if it matches the
// scheme's default port (443 for https, 80 for http).
func normalizeHost(u *url.URL) string {
host, port, err := net.SplitHostPort(u.Host)
if err != nil {
return u.Host
}
if (u.Scheme == "https" && port == "443") || (u.Scheme == "http" && port == "80") {
return host
}
return u.Host
}
// setTrustedForwardingHeaders appends to the existing forwarding header chain
// and preserves upstream-provided headers when the direct connection is from
// a trusted proxy.
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
// Append the direct connection IP to the existing X-Forwarded-For chain.
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
} else {
r.Out.Header.Set("X-Forwarded-For", clientIP)
}
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" {
r.Out.Header.Set("X-Real-IP", realIP)
} else {
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
r.Out.Header.Set("X-Real-IP", resolved)
}
// Preserve upstream X-Forwarded-Host if present.
if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" {
r.Out.Header.Set("X-Forwarded-Host", fwdHost)
} else {
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
}
// Trust upstream X-Forwarded-Proto; fall back to local resolution.
if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
r.Out.Header.Set("X-Forwarded-Proto", fwdProto)
} else {
r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS))
}
// Trust upstream X-Forwarded-Port; fall back to local computation.
if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" {
r.Out.Header.Set("X-Forwarded-Port", fwdPort)
} else {
resolvedProto := r.Out.Header.Get("X-Forwarded-Proto")
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto))
}
}
// setUntrustedForwardingHeaders strips all incoming forwarding headers and
// sets them fresh based on the direct connection. This is the default
// behavior when no trusted proxies are configured or the direct connection
// is from an untrusted source.
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
r.Out.Header.Set("X-Forwarded-For", clientIP)
r.Out.Header.Set("X-Real-IP", clientIP)
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
r.Out.Header.Set("X-Forwarded-Proto", proto)
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
}
// stripSessionCookie removes the proxy's session cookie from the outgoing
// request while preserving all other cookies.
func stripSessionCookie(r *httputil.ProxyRequest) {
cookies := r.In.Cookies()
r.Out.Header.Del("Cookie")
for _, c := range cookies {
if c.Name != auth.SessionCookieName {
r.Out.AddCookie(c)
}
}
}
// stripSessionTokenQuery removes the OIDC session_token query parameter from
// the outgoing URL to prevent credential leakage to backends.
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
q := r.Out.URL.Query()
if q.Has("session_token") {
q.Del("session_token")
r.Out.URL.RawQuery = q.Encode()
}
}
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
// which is always in host:port format.
func extractClientIP(remoteAddr string) string {
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return remoteAddr
}
return ip
}
// extractForwardedPort returns the port from the Host header if present,
// otherwise defaults to the standard port for the resolved protocol.
func extractForwardedPort(host, resolvedProto string) string {
_, port, err := net.SplitHostPort(host)
if err == nil && port != "" {
return port
}
if resolvedProto == "https" {
return "443"
}
return "80"
}
// proxyErrorHandler handles errors from the reverse proxy and serves
// user-friendly error pages instead of raw error responses.
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
if cd := CapturedDataFromContext(r.Context()); cd != nil {
cd.SetOrigin(OriginProxyError)
}
requestID := getRequestID(r)
clientIP := getClientIP(r)
title, message, code, status := classifyProxyError(err)
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
web.ServeErrorPage(w, r, code, title, message, requestID, status)
}
// getClientIP retrieves the resolved client IP from context.
func getClientIP(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
return capturedData.GetClientIP()
}
return ""
}
// getRequestID retrieves the request ID from context or returns empty string.
func getRequestID(r *http.Request) string {
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
return capturedData.GetRequestID()
}
return ""
}
// classifyProxyError determines the appropriate error title, message, HTTP
// status code, and component status based on the error type.
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
switch {
case errors.Is(err, context.DeadlineExceeded),
isNetTimeout(err):
return "Request Timeout",
"The request timed out while trying to reach the service. Please refresh the page and try again.",
http.StatusGatewayTimeout,
web.ErrorStatus{Proxy: true, Destination: false}
case errors.Is(err, context.Canceled):
return "Request Canceled",
"The request was canceled before it could be completed. Please refresh the page and try again.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
case errors.Is(err, roundtrip.ErrNoAccountID):
return "Configuration Error",
"The request could not be processed due to a configuration issue. Please refresh the page and try again.",
http.StatusInternalServerError,
web.ErrorStatus{Proxy: false, Destination: false}
case errors.Is(err, roundtrip.ErrNoPeerConnection),
errors.Is(err, roundtrip.ErrClientStartFailed):
return "Proxy Not Connected",
"The proxy is not connected to the NetBird network. Please try again later or contact your administrator.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: false, Destination: false}
case errors.Is(err, roundtrip.ErrTooManyInflight):
return "Service Overloaded",
"The service is currently handling too many requests. Please try again shortly.",
http.StatusServiceUnavailable,
web.ErrorStatus{Proxy: true, Destination: false}
case isConnectionRefused(err):
return "Service Unavailable",
"The connection to the service was refused. Please verify that the service is running and try again.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
case isHostUnreachable(err):
return "Peer Not Connected",
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
}
return "Connection Error",
"An unexpected error occurred while connecting to the service. Please try again later.",
http.StatusBadGateway,
web.ErrorStatus{Proxy: true, Destination: false}
}
// isConnectionRefused checks for connection refused errors by inspecting
// the inner error of a *net.OpError. This handles both standard net errors
// (where the inner error is a *os.SyscallError with "connection refused")
// and gVisor netstack errors ("connection was refused").
func isConnectionRefused(err error) bool {
return opErrorContains(err, "refused")
}
// isHostUnreachable checks for host/network unreachable errors by inspecting
// the inner error of a *net.OpError. Covers standard net ("no route to host",
// "network is unreachable") and gVisor ("host is unreachable", etc.).
func isHostUnreachable(err error) bool {
return opErrorContains(err, "unreachable") || opErrorContains(err, "no route to host")
}
// isNetTimeout checks whether the error is a network timeout using the
// net.Error interface.
func isNetTimeout(err error) bool {
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
// opErrorContains extracts the inner error from a *net.OpError and checks
// whether its message contains the given substring. This handles gVisor
// netstack errors which wrap tcpip errors as plain strings rather than
// syscall.Errno values.
func opErrorContains(err error, substr string) bool {
var opErr *net.OpError
if errors.As(err, &opErr) && opErr.Err != nil {
return strings.Contains(opErr.Err.Error(), substr)
}
return false
}

View File

@@ -0,0 +1,966 @@
package proxy
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/netip"
"net/url"
"os"
"syscall"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/proxy/auth"
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
"github.com/netbirdio/netbird/proxy/web"
)
func TestRewriteFunc_HostRewriting(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites host to backend by default", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
})
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
rewrite := p.rewriteFunc(target, "", true)
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
rewrite(pr)
assert.Equal(t, "public.example.com", pr.Out.Host,
"Host header should be the original client host")
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
"URL host (used for TLS/SNI) must still point to the backend")
})
}
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"should be set to the connecting client IP")
})
t.Run("strips spoofed X-Forwarded-For from client", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"spoofed XFF must be replaced, not appended to")
})
t.Run("strips spoofed X-Real-IP from client", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Real-IP", "10.0.0.1")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"spoofed X-Real-IP must be replaced")
})
}
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "myapp.example.com:8443", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("auto detects https from TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("auto detects http without TLS", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
// No TLS, but forced to https
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("forced http proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "http"}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
pr.In.TLS = &tls.ConnectionState{}
rewrite(pr)
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
})
}
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
t.Run("strips nb_session cookie", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})
rewrite(pr)
cookies := pr.Out.Cookies()
for _, c := range cookies {
assert.NotEqual(t, auth.SessionCookieName, c.Name,
"proxy session cookie must not be forwarded to backend")
}
})
t.Run("preserves other cookies", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"})
pr.In.AddCookie(&http.Cookie{Name: "app_session", Value: "app-value"})
pr.In.AddCookie(&http.Cookie{Name: "tracking", Value: "track-value"})
rewrite(pr)
cookies := pr.Out.Cookies()
cookieNames := make([]string, 0, len(cookies))
for _, c := range cookies {
cookieNames = append(cookieNames, c.Name)
}
assert.Contains(t, cookieNames, "app_session", "non-proxy cookies should be preserved")
assert.Contains(t, cookieNames, "tracking", "non-proxy cookies should be preserved")
assert.NotContains(t, cookieNames, auth.SessionCookieName, "proxy cookie must be stripped")
})
t.Run("handles request with no cookies", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
rewrite(pr)
assert.Empty(t, pr.Out.Header.Get("Cookie"))
})
}
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
p := &ReverseProxy{forwardedProto: "auto"}
rewrite := p.rewriteFunc(target, "", false)
t.Run("strips session_token query parameter", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
rewrite(pr)
assert.Empty(t, pr.Out.URL.Query().Get("session_token"),
"OIDC session token must be stripped from backend request")
assert.Equal(t, "keep", pr.Out.URL.Query().Get("other"),
"other query parameters must be preserved")
})
t.Run("preserves query when no session_token present", func(t *testing.T) {
pr := newProxyRequest(t, "http://example.com/api?foo=bar&baz=qux", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "bar", pr.Out.URL.Query().Get("foo"))
assert.Equal(t, "qux", pr.Out.URL.Query().Get("baz"))
})
}
func TestRewriteFunc_URLRewriting(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080/app")
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "http", pr.Out.URL.Scheme)
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host)
assert.Equal(t, "/app/somepath", pr.Out.URL.Path,
"SetURL should join the target base path with the request path")
})
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false)
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "https", pr.Out.URL.Scheme)
assert.Equal(t, "backend.example.org:443", pr.Out.URL.Host)
assert.Equal(t, "/app/", pr.Out.URL.Path,
"matched path prefix should be stripped before joining with target path")
})
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.org:443/app")
rewrite := p.rewriteFunc(target, "/app", false)
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/app/article/123", pr.Out.URL.Path,
"subpath after matched prefix should be preserved")
})
}
func TestExtractClientIP(t *testing.T) {
tests := []struct {
name string
remoteAddr string
expected string
}{
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
{"IPv6 with port", "[::1]:12345", "::1"},
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
{"IPv6 without brackets fallback", "::1", "::1"},
{"empty string fallback", "", ""},
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
})
}
}
func TestExtractForwardedPort(t *testing.T) {
tests := []struct {
name string
host string
resolvedProto string
expected string
}{
{"explicit port in host", "example.com:8443", "https", "8443"},
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
{"no port defaults to 443 for https", "example.com", "https", "443"},
{"no port defaults to 80 for http", "example.com", "http", "80"},
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
{"IPv6 host without port", "::1", "https", "443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
})
}
}
func TestRewriteFunc_TrustedProxy(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "203.0.113.50, 10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"))
})
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
pr.In.Header.Set("X-Real-IP", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"))
})
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"should resolve real client through trusted chain")
})
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
rewrite(pr)
assert.Equal(t, "original.example.com", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Proto", "https")
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
})
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-Port", "8443")
rewrite(pr)
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
})
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"),
"should use configured forwardedProto as fallback")
})
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"))
})
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
pr.In.Header.Set("X-Real-IP", "evil")
pr.In.Header.Set("X-Forwarded-Host", "evil.example.com")
pr.In.Header.Set("X-Forwarded-Proto", "https")
pr.In.Header.Set("X-Forwarded-Port", "9999")
rewrite(pr)
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
"untrusted: XFF must be replaced")
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
"untrusted: X-Real-IP must be replaced")
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"),
"untrusted: host must be from direct connection")
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"),
"untrusted: proto must be locally resolved")
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"),
"untrusted: port must be locally computed")
})
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
rewrite(pr)
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
"nil trusted list: should strip and use RemoteAddr")
})
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
rewrite := p.rewriteFunc(target, "", false)
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
rewrite(pr)
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
"no upstream XFF: should set direct connection IP")
})
}
// TestRewriteFunc_PathForwarding verifies what path the backend actually
// receives given different configurations. This simulates the full pipeline:
// management builds a target URL (with matching prefix baked into the path),
// then the proxy strips the prefix and SetURL re-joins with the target path.
func TestRewriteFunc_PathForwarding(t *testing.T) {
p := &ReverseProxy{forwardedProto: "auto"}
// Simulate what ToProtoMapping does: target URL includes the matching
// prefix as its path component, so the proxy strips-then-re-adds.
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
// Management builds: path="/heise", target="https://heise.de:443/heise"
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise/", pr.Out.URL.Path,
"backend sees /heise/ because prefix is stripped then re-added by SetURL")
})
t.Run("subpath under prefix also preserved", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443/heise")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise/article/123", pr.Out.URL.Path,
"subpath is preserved on top of the re-added prefix")
})
// What the behavior WOULD be if target URL had no path (true stripping)
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/", pr.Out.URL.Path,
"without path in target URL, backend sees / (true prefix stripping)")
})
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
target, _ := url.Parse("https://heise.de:443")
rewrite := p.rewriteFunc(target, "/heise", false)
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/article/123", pr.Out.URL.Path,
"without path in target URL, prefix is truly stripped")
})
// Root path "/" — no stripping expected
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
target, _ := url.Parse("https://backend.example.com:443/")
rewrite := p.rewriteFunc(target, "/", false)
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
rewrite(pr)
assert.Equal(t, "/heise", pr.Out.URL.Path,
"root path match must not strip anything")
})
}
func TestRewriteLocationFunc(t *testing.T) {
target, _ := url.Parse("http://backend.internal:8080")
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
newReq := func(rawURL string) *http.Request {
t.Helper()
r := httptest.NewRequest(http.MethodGet, rawURL, nil)
parsed, _ := url.Parse(rawURL)
r.Host = parsed.Host
return r
}
run := func(p *ReverseProxy, matchedPath string, inReq *http.Request, location string) (*http.Response, error) {
t.Helper()
modifyResp := p.rewriteLocationFunc(target, matchedPath, inReq) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
if location != "" {
resp.Header.Set("Location", location)
}
err := modifyResp(resp)
return resp, err
}
t.Run("rewrites Location pointing to backend", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
t.Run("does not rewrite Location pointing to other host", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"https://other.example.com/path")
require.NoError(t, err)
assert.Equal(t, "https://other.example.com/path", resp.Header.Get("Location"))
})
t.Run("does not rewrite relative Location", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"/dashboard")
require.NoError(t, err)
assert.Equal(t, "/dashboard", resp.Header.Get("Location"))
})
t.Run("re-adds stripped path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/users"), //nolint:bodyclose
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("uses resolved proto for scheme", func(t *testing.T) {
resp, err := run(newProxy("auto"), "", newReq("http://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path")
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"))
})
t.Run("no-op when Location header is empty", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), "") //nolint:bodyclose
require.NoError(t, err)
assert.Empty(t, resp.Header.Get("Location"))
})
t.Run("does not prepend root path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/", newReq("https://public.example.com/login"), //nolint:bodyclose
"http://backend.internal:8080/login")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
})
// --- Edge cases: query parameters and fragments ---
t.Run("preserves query parameters", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/login?redirect=%2Fdashboard&lang=en")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/login?redirect=%2Fdashboard&lang=en", resp.Header.Get("Location"))
})
t.Run("preserves fragment", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/docs#section-2")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/docs#section-2", resp.Header.Get("Location"))
})
t.Run("preserves query parameters and fragment together", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/search?q=test&page=1#results")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/search?q=test&page=1#results", resp.Header.Get("Location"))
})
t.Run("preserves query parameters with path prefix re-added", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/search"), //nolint:bodyclose
"http://backend.internal:8080/search?q=hello")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/search?q=hello", resp.Header.Get("Location"))
})
// --- Edge cases: slash handling ---
t.Run("no double slash when matchedPath has trailing slash", func(t *testing.T) {
resp, err := run(newProxy("https"), "/api/", newReq("https://public.example.com/api/users"), //nolint:bodyclose
"http://backend.internal:8080/users")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app", newReq("https://public.example.com/app/"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("backend redirect to root with trailing-slash path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/app/", newReq("https://public.example.com/app/"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
})
t.Run("preserves trailing slash on redirect path", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path/", resp.Header.Get("Location"))
})
t.Run("backend redirect to bare root", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
"http://backend.internal:8080/")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/", resp.Header.Get("Location"))
})
// --- Edge cases: host/port matching ---
t.Run("does not rewrite when backend host matches but port differs", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:9090/other")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal:9090/other", resp.Header.Get("Location"),
"Different port means different host authority, must not rewrite")
})
t.Run("rewrites when redirect omits default port matching target", func(t *testing.T) {
// Target is backend.internal:8080, redirect is to backend.internal (no port).
// These are different authorities, so should NOT rewrite.
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal/path")
require.NoError(t, err)
assert.Equal(t, "http://backend.internal/path", resp.Header.Get("Location"),
"backend.internal != backend.internal:8080, must not rewrite")
})
t.Run("rewrites when target has :443 but redirect omits it for https", func(t *testing.T) {
// Target: heise.de:443, redirect: https://heise.de/path (no :443 because it's default)
// Per RFC 3986, these are the same authority.
target443, _ := url.Parse("https://heise.de:443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target443, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de:443 and heise.de are the same for https")
})
t.Run("rewrites when target has :80 but redirect omits it for http", func(t *testing.T) {
target80, _ := url.Parse("http://backend.local:80")
p := newProxy("http")
modifyResp := p.rewriteLocationFunc(target80, "", newReq("http://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "http://backend.local/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"),
"backend.local:80 and backend.local are the same for http")
})
t.Run("rewrites when redirect has :443 but target omits it", func(t *testing.T) {
targetNoPort, _ := url.Parse("https://heise.de")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(targetNoPort, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://heise.de:443/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
"heise.de and heise.de:443 are the same for https")
})
t.Run("does not conflate non-default ports", func(t *testing.T) {
target8443, _ := url.Parse("https://backend.internal:8443")
p := newProxy("https")
modifyResp := p.rewriteLocationFunc(target8443, "", newReq("https://public.example.com/")) //nolint:bodyclose
resp := &http.Response{Header: http.Header{}}
resp.Header.Set("Location", "https://backend.internal/path")
err := modifyResp(resp)
require.NoError(t, err)
assert.Equal(t, "https://backend.internal/path", resp.Header.Get("Location"),
"backend.internal:8443 != backend.internal (port 443), must not rewrite")
})
// --- Edge cases: encoded paths ---
t.Run("preserves percent-encoded path segments", func(t *testing.T) {
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
"http://backend.internal:8080/path%20with%20spaces/file%2Fname")
require.NoError(t, err)
loc := resp.Header.Get("Location")
assert.Contains(t, loc, "public.example.com")
parsed, err := url.Parse(loc)
require.NoError(t, err)
assert.Equal(t, "/path with spaces/file/name", parsed.Path)
})
t.Run("preserves encoded query parameters with path prefix", func(t *testing.T) {
resp, err := run(newProxy("https"), "/v1", newReq("https://public.example.com/v1/"), //nolint:bodyclose
"http://backend.internal:8080/redirect?url=http%3A%2F%2Fexample.com")
require.NoError(t, err)
assert.Equal(t, "https://public.example.com/v1/redirect?url=http%3A%2F%2Fexample.com", resp.Header.Get("Location"))
})
}
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
// Out is a shallow clone of In with headers copied.
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
t.Helper()
parsed, err := url.Parse(rawURL)
require.NoError(t, err)
in := httptest.NewRequest(http.MethodGet, rawURL, nil)
in.RemoteAddr = remoteAddr
in.Host = parsed.Host
out := in.Clone(in.Context())
out.Header = in.Header.Clone()
return &httputil.ProxyRequest{In: in, Out: out}
}
func TestClassifyProxyError(t *testing.T) {
tests := []struct {
name string
err error
wantTitle string
wantCode int
wantStatus web.ErrorStatus
}{
{
name: "context deadline exceeded",
err: context.DeadlineExceeded,
wantTitle: "Request Timeout",
wantCode: http.StatusGatewayTimeout,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "wrapped deadline exceeded",
err: fmt.Errorf("dial: %w", context.DeadlineExceeded),
wantTitle: "Request Timeout",
wantCode: http.StatusGatewayTimeout,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "context canceled",
err: context.Canceled,
wantTitle: "Request Canceled",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "no account ID",
err: roundtrip.ErrNoAccountID,
wantTitle: "Configuration Error",
wantCode: http.StatusInternalServerError,
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
},
{
name: "no peer connection",
err: fmt.Errorf("%w for account: abc", roundtrip.ErrNoPeerConnection),
wantTitle: "Proxy Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
},
{
name: "client not started",
err: fmt.Errorf("%w: %w", roundtrip.ErrClientStartFailed, errors.New("engine init failed")),
wantTitle: "Proxy Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
},
{
name: "syscall ECONNREFUSED via os.SyscallError",
err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED},
},
wantTitle: "Service Unavailable",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "gvisor connection was refused",
err: &net.OpError{
Op: "connect",
Net: "tcp",
Err: errors.New("connection was refused"),
},
wantTitle: "Service Unavailable",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "syscall EHOSTUNREACH via os.SyscallError",
err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
},
wantTitle: "Peer Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "syscall ENETUNREACH via os.SyscallError",
err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ENETUNREACH},
},
wantTitle: "Peer Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "gvisor host is unreachable",
err: &net.OpError{
Op: "connect",
Net: "tcp",
Err: errors.New("host is unreachable"),
},
wantTitle: "Peer Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "gvisor network is unreachable",
err: &net.OpError{
Op: "connect",
Net: "tcp",
Err: errors.New("network is unreachable"),
},
wantTitle: "Peer Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "standard no route to host",
err: &net.OpError{
Op: "dial",
Net: "tcp",
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
},
wantTitle: "Peer Not Connected",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
{
name: "unknown error falls to default",
err: errors.New("something unexpected"),
wantTitle: "Connection Error",
wantCode: http.StatusBadGateway,
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
title, _, code, status := classifyProxyError(tt.err)
assert.Equal(t, tt.wantTitle, title, "title")
assert.Equal(t, tt.wantCode, code, "status code")
assert.Equal(t, tt.wantStatus, status, "component status")
})
}
}

View File

@@ -0,0 +1,84 @@
package proxy
import (
"net"
"net/http"
"net/url"
"sort"
"strings"
"github.com/netbirdio/netbird/proxy/internal/types"
)
type Mapping struct {
ID string
AccountID types.AccountID
Host string
Paths map[string]*url.URL
PassHostHeader bool
RewriteRedirects bool
}
type targetResult struct {
url *url.URL
matchedPath string
serviceID string
accountID types.AccountID
passHostHeader bool
rewriteRedirects bool
}
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
p.mappingsMux.RLock()
defer p.mappingsMux.RUnlock()
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
host := req.Host
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
m, exists := p.mappings[host]
if !exists {
p.logger.Debugf("no mapping found for host: %s", host)
return targetResult{}, false
}
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
paths := make([]string, 0, len(m.Paths))
for path := range m.Paths {
paths = append(paths, path)
}
sort.Slice(paths, func(i, j int) bool {
return len(paths[i]) > len(paths[j])
})
for _, path := range paths {
if strings.HasPrefix(req.URL.Path, path) {
target := m.Paths[path]
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
return targetResult{
url: target,
matchedPath: path,
serviceID: m.ID,
accountID: m.AccountID,
passHostHeader: m.PassHostHeader,
rewriteRedirects: m.RewriteRedirects,
}, true
}
}
p.logger.Debugf("no path match for host: %s, path: %s", host, req.URL.Path)
return targetResult{}, false
}
func (p *ReverseProxy) AddMapping(m Mapping) {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
p.mappings[m.Host] = m
}
func (p *ReverseProxy) RemoveMapping(m Mapping) {
p.mappingsMux.Lock()
defer p.mappingsMux.Unlock()
delete(p.mappings, m.Host)
}

View File

@@ -0,0 +1,60 @@
package proxy
import (
"net/netip"
"strings"
)
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
if len(trusted) == 0 {
return false
}
addr, err := netip.ParseAddr(ipStr)
if err != nil {
return false
}
for _, prefix := range trusted {
if prefix.Contains(addr) {
return true
}
}
return false
}
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
// It walks the XFF chain right-to-left, skipping IPs that match trusted prefixes.
// The first untrusted IP is the real client.
//
// If the trusted list is empty or remoteAddr is not trusted, it returns the
// remoteAddr IP directly (ignoring any forwarding headers).
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
remoteIP := extractClientIP(remoteAddr)
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
return remoteIP
}
if xff == "" {
return remoteIP
}
parts := strings.Split(xff, ",")
for i := len(parts) - 1; i >= 0; i-- {
ip := strings.TrimSpace(parts[i])
if ip == "" {
continue
}
if !IsTrustedProxy(ip, trusted) {
return ip
}
}
// All IPs in XFF are trusted; return the leftmost as best guess.
if first := strings.TrimSpace(parts[0]); first != "" {
return first
}
return remoteIP
}

View File

@@ -0,0 +1,129 @@
package proxy
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsTrustedProxy(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("fd00::/8"),
}
tests := []struct {
name string
ip string
trusted []netip.Prefix
want bool
}{
{"empty trusted list", "10.0.0.1", nil, false},
{"IP within /8 prefix", "10.1.2.3", trusted, true},
{"IP within /24 prefix", "192.168.1.100", trusted, true},
{"IP outside all prefixes", "203.0.113.50", trusted, false},
{"boundary IP just outside prefix", "192.168.2.1", trusted, false},
{"unparsable IP", "not-an-ip", trusted, false},
{"IPv6 in trusted range", "fd00::1", trusted, true},
{"IPv6 outside range", "2001:db8::1", trusted, false},
{"empty string", "", trusted, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted))
})
}
}
func TestResolveClientIP(t *testing.T) {
trusted := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("172.16.0.0/12"),
}
tests := []struct {
name string
remoteAddr string
xff string
trusted []netip.Prefix
want string
}{
{
name: "empty trusted list returns RemoteAddr",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4",
trusted: nil,
want: "203.0.113.50",
},
{
name: "untrusted RemoteAddr ignores XFF",
remoteAddr: "203.0.113.50:9999",
xff: "1.2.3.4, 10.0.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with single client in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr walks past trusted entries in XFF",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
remoteAddr: "10.0.0.1:5000",
xff: "",
trusted: trusted,
want: "10.0.0.1",
},
{
name: "all XFF IPs trusted returns leftmost",
remoteAddr: "10.0.0.1:5000",
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
trusted: trusted,
want: "10.0.0.2",
},
{
name: "XFF with whitespace",
remoteAddr: "10.0.0.1:5000",
xff: " 203.0.113.50 , 10.0.0.2 ",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "XFF with empty segments",
remoteAddr: "10.0.0.1:5000",
xff: "203.0.113.50,,10.0.0.2",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "multi-hop with mixed trust",
remoteAddr: "10.0.0.1:5000",
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
trusted: trusted,
want: "203.0.113.50",
},
{
name: "RemoteAddr without port",
remoteAddr: "10.0.0.1",
xff: "203.0.113.50",
trusted: trusted,
want: "203.0.113.50",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted))
})
}
}