mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
[management, reverse proxy] Add reverse proxy feature (#5291)
* implement reverse proxy --------- Co-authored-by: Alisdair MacLeod <git@alisdairmacleod.co.uk> Co-authored-by: mlsmaycon <mlsmaycon@gmail.com> Co-authored-by: Eduard Gert <kontakt@eduardgert.de> Co-authored-by: Viktor Liu <viktor@netbird.io> Co-authored-by: Diego Noguês <diego.sure@gmail.com> Co-authored-by: Diego Noguês <49420+diegocn@users.noreply.github.com> Co-authored-by: Bethuel Mmbaga <bethuelmbaga12@gmail.com> Co-authored-by: Zoltan Papp <zoltan.pmail@gmail.com> Co-authored-by: Ashley Mensah <ashleyamo982@gmail.com>
This commit is contained in:
187
proxy/internal/proxy/context.go
Normal file
187
proxy/internal/proxy/context.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type requestContextKey string
|
||||
|
||||
const (
|
||||
serviceIdKey requestContextKey = "serviceId"
|
||||
accountIdKey requestContextKey = "accountId"
|
||||
capturedDataKey requestContextKey = "capturedData"
|
||||
)
|
||||
|
||||
// ResponseOrigin indicates where a response was generated.
|
||||
type ResponseOrigin int
|
||||
|
||||
const (
|
||||
// OriginBackend means the response came from the backend service.
|
||||
OriginBackend ResponseOrigin = iota
|
||||
// OriginNoRoute means the proxy had no matching host or path.
|
||||
OriginNoRoute
|
||||
// OriginProxyError means the proxy failed to reach the backend.
|
||||
OriginProxyError
|
||||
// OriginAuth means the proxy intercepted the request for authentication.
|
||||
OriginAuth
|
||||
)
|
||||
|
||||
func (o ResponseOrigin) String() string {
|
||||
switch o {
|
||||
case OriginNoRoute:
|
||||
return "no_route"
|
||||
case OriginProxyError:
|
||||
return "proxy_error"
|
||||
case OriginAuth:
|
||||
return "auth"
|
||||
default:
|
||||
return "backend"
|
||||
}
|
||||
}
|
||||
|
||||
// CapturedData is a mutable struct that allows downstream handlers
|
||||
// to pass data back up the middleware chain.
|
||||
type CapturedData struct {
|
||||
mu sync.RWMutex
|
||||
RequestID string
|
||||
ServiceId string
|
||||
AccountId types.AccountID
|
||||
Origin ResponseOrigin
|
||||
ClientIP string
|
||||
UserID string
|
||||
AuthMethod string
|
||||
}
|
||||
|
||||
// GetRequestID safely gets the request ID
|
||||
func (c *CapturedData) GetRequestID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.RequestID
|
||||
}
|
||||
|
||||
// SetServiceId safely sets the service ID
|
||||
func (c *CapturedData) SetServiceId(serviceId string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ServiceId = serviceId
|
||||
}
|
||||
|
||||
// GetServiceId safely gets the service ID
|
||||
func (c *CapturedData) GetServiceId() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ServiceId
|
||||
}
|
||||
|
||||
// SetAccountId safely sets the account ID
|
||||
func (c *CapturedData) SetAccountId(accountId types.AccountID) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AccountId = accountId
|
||||
}
|
||||
|
||||
// GetAccountId safely gets the account ID
|
||||
func (c *CapturedData) GetAccountId() types.AccountID {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AccountId
|
||||
}
|
||||
|
||||
// SetOrigin safely sets the response origin
|
||||
func (c *CapturedData) SetOrigin(origin ResponseOrigin) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.Origin = origin
|
||||
}
|
||||
|
||||
// GetOrigin safely gets the response origin
|
||||
func (c *CapturedData) GetOrigin() ResponseOrigin {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.Origin
|
||||
}
|
||||
|
||||
// SetClientIP safely sets the resolved client IP.
|
||||
func (c *CapturedData) SetClientIP(ip string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.ClientIP = ip
|
||||
}
|
||||
|
||||
// GetClientIP safely gets the resolved client IP.
|
||||
func (c *CapturedData) GetClientIP() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.ClientIP
|
||||
}
|
||||
|
||||
// SetUserID safely sets the authenticated user ID.
|
||||
func (c *CapturedData) SetUserID(userID string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.UserID = userID
|
||||
}
|
||||
|
||||
// GetUserID safely gets the authenticated user ID.
|
||||
func (c *CapturedData) GetUserID() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.UserID
|
||||
}
|
||||
|
||||
// SetAuthMethod safely sets the authentication method used.
|
||||
func (c *CapturedData) SetAuthMethod(method string) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.AuthMethod = method
|
||||
}
|
||||
|
||||
// GetAuthMethod safely gets the authentication method used.
|
||||
func (c *CapturedData) GetAuthMethod() string {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
return c.AuthMethod
|
||||
}
|
||||
|
||||
// WithCapturedData adds a CapturedData struct to the context
|
||||
func WithCapturedData(ctx context.Context, data *CapturedData) context.Context {
|
||||
return context.WithValue(ctx, capturedDataKey, data)
|
||||
}
|
||||
|
||||
// CapturedDataFromContext retrieves the CapturedData from context
|
||||
func CapturedDataFromContext(ctx context.Context) *CapturedData {
|
||||
v := ctx.Value(capturedDataKey)
|
||||
data, ok := v.(*CapturedData)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func withServiceId(ctx context.Context, serviceId string) context.Context {
|
||||
return context.WithValue(ctx, serviceIdKey, serviceId)
|
||||
}
|
||||
|
||||
func ServiceIdFromContext(ctx context.Context) string {
|
||||
v := ctx.Value(serviceIdKey)
|
||||
serviceId, ok := v.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return serviceId
|
||||
}
|
||||
func withAccountId(ctx context.Context, accountId types.AccountID) context.Context {
|
||||
return context.WithValue(ctx, accountIdKey, accountId)
|
||||
}
|
||||
|
||||
func AccountIdFromContext(ctx context.Context) types.AccountID {
|
||||
v := ctx.Value(accountIdKey)
|
||||
accountId, ok := v.(types.AccountID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return accountId
|
||||
}
|
||||
130
proxy/internal/proxy/proxy_bench_test.go
Normal file
130
proxy/internal/proxy/proxy_bench_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package proxy_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/proxy"
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type nopTransport struct{}
|
||||
|
||||
func (nopTransport) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Body: http.NoBody,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTP(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: map[string]*url.URL{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://app.example.com", nil)
|
||||
req.Host = "app.example.com"
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTPHostCount(b *testing.B) {
|
||||
hostCounts := []int{1, 10, 100, 1_000, 10_000}
|
||||
|
||||
for _, hostCount := range hostCounts {
|
||||
b.Run(fmt.Sprintf("hosts=%d", hostCount), func(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
|
||||
var target string
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(hostCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
for i := range hostCount {
|
||||
id := rand.Text()
|
||||
host := fmt.Sprintf("%s.example.com", id)
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = id
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: id,
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: host,
|
||||
Paths: map[string]*url.URL{
|
||||
"/": {
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:8080",
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://"+target+"/", nil)
|
||||
req.Host = target
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkServeHTTPPathCount(b *testing.B) {
|
||||
pathCounts := []int{1, 5, 10, 25, 50}
|
||||
|
||||
for _, pathCount := range pathCounts {
|
||||
b.Run(fmt.Sprintf("paths=%d", pathCount), func(b *testing.B) {
|
||||
rp := proxy.NewReverseProxy(nopTransport{}, "http", nil, nil)
|
||||
|
||||
var target string
|
||||
targetIndex, err := rand.Int(rand.Reader, big.NewInt(int64(pathCount)))
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
paths := make(map[string]*url.URL, pathCount)
|
||||
for i := range pathCount {
|
||||
path := "/" + rand.Text()
|
||||
if int64(i) == targetIndex.Int64() {
|
||||
target = path
|
||||
}
|
||||
paths[path] = &url.URL{
|
||||
Scheme: "http",
|
||||
Host: "10.0.0.1:" + fmt.Sprintf("%d", 8080+i),
|
||||
}
|
||||
}
|
||||
rp.AddMapping(proxy.Mapping{
|
||||
ID: rand.Text(),
|
||||
AccountID: types.AccountID(rand.Text()),
|
||||
Host: "app.example.com",
|
||||
Paths: paths,
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "http://app.example.com"+target, nil)
|
||||
req.Host = "app.example.com"
|
||||
req.RemoteAddr = "203.0.113.50:12345"
|
||||
|
||||
for b.Loop() {
|
||||
rp.ServeHTTP(httptest.NewRecorder(), req)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
406
proxy/internal/proxy/reverseproxy.go
Normal file
406
proxy/internal/proxy/reverseproxy.go
Normal file
@@ -0,0 +1,406 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
type ReverseProxy struct {
|
||||
transport http.RoundTripper
|
||||
// forwardedProto overrides the X-Forwarded-Proto header value.
|
||||
// Valid values: "auto" (detect from TLS), "http", "https".
|
||||
forwardedProto string
|
||||
// trustedProxies is a list of IP prefixes for trusted upstream proxies.
|
||||
// When the direct connection comes from a trusted proxy, forwarding
|
||||
// headers are preserved and appended to instead of being stripped.
|
||||
trustedProxies []netip.Prefix
|
||||
mappingsMux sync.RWMutex
|
||||
mappings map[string]Mapping
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewReverseProxy configures a new NetBird ReverseProxy.
|
||||
// This is a wrapper around an httputil.ReverseProxy set
|
||||
// to dynamically route requests based on internal mapping
|
||||
// between requested URLs and targets.
|
||||
// The internal mappings can be modified using the AddMapping
|
||||
// and RemoveMapping functions.
|
||||
func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy {
|
||||
if logger == nil {
|
||||
logger = log.StandardLogger()
|
||||
}
|
||||
return &ReverseProxy{
|
||||
transport: transport,
|
||||
forwardedProto: forwardedProto,
|
||||
trustedProxies: trustedProxies,
|
||||
mappings: make(map[string]Mapping),
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
result, exists := p.findTargetForRequest(r)
|
||||
if !exists {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginNoRoute)
|
||||
}
|
||||
requestID := getRequestID(r)
|
||||
web.ServeErrorPage(w, r, http.StatusNotFound, "Service Not Found",
|
||||
"The requested service could not be found. Please check the URL, try refreshing, or check if the peer is running. If that doesn't work, see our documentation for help.",
|
||||
requestID, web.ErrorStatus{Proxy: true, Destination: false})
|
||||
return
|
||||
}
|
||||
|
||||
// Set the serviceId in the context for later retrieval.
|
||||
ctx := withServiceId(r.Context(), result.serviceID)
|
||||
// Set the accountId in the context for later retrieval (for middleware).
|
||||
ctx = withAccountId(ctx, result.accountID)
|
||||
// Set the accountId in the context for the roundtripper to use.
|
||||
ctx = roundtrip.WithAccountID(ctx, result.accountID)
|
||||
|
||||
// Also populate captured data if it exists (allows middleware to read after handler completes).
|
||||
// This solves the problem of passing data UP the middleware chain: we put a mutable struct
|
||||
// pointer in the context, and mutate the struct here so outer middleware can read it.
|
||||
if capturedData := CapturedDataFromContext(ctx); capturedData != nil {
|
||||
capturedData.SetServiceId(result.serviceID)
|
||||
capturedData.SetAccountId(result.accountID)
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
Rewrite: p.rewriteFunc(result.url, result.matchedPath, result.passHostHeader),
|
||||
Transport: p.transport,
|
||||
ErrorHandler: proxyErrorHandler,
|
||||
}
|
||||
if result.rewriteRedirects {
|
||||
rp.ModifyResponse = p.rewriteLocationFunc(result.url, result.matchedPath, r) //nolint:bodyclose
|
||||
}
|
||||
rp.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
// rewriteFunc returns a Rewrite function for httputil.ReverseProxy that rewrites
|
||||
// inbound requests to target the backend service while setting security-relevant
|
||||
// forwarding headers and stripping proxy authentication credentials.
|
||||
// When passHostHeader is true, the original client Host header is preserved
|
||||
// instead of being rewritten to the backend's address.
|
||||
func (p *ReverseProxy) rewriteFunc(target *url.URL, matchedPath string, passHostHeader bool) func(r *httputil.ProxyRequest) {
|
||||
return func(r *httputil.ProxyRequest) {
|
||||
// Strip the matched path prefix from the incoming request path before
|
||||
// SetURL joins it with the target's base path, avoiding path duplication.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
r.Out.URL.Path = strings.TrimPrefix(r.Out.URL.Path, matchedPath)
|
||||
if r.Out.URL.Path == "" {
|
||||
r.Out.URL.Path = "/"
|
||||
}
|
||||
r.Out.URL.RawPath = ""
|
||||
}
|
||||
|
||||
r.SetURL(target)
|
||||
if passHostHeader {
|
||||
r.Out.Host = r.In.Host
|
||||
} else {
|
||||
r.Out.Host = target.Host
|
||||
}
|
||||
|
||||
clientIP := extractClientIP(r.In.RemoteAddr)
|
||||
|
||||
if IsTrustedProxy(clientIP, p.trustedProxies) {
|
||||
p.setTrustedForwardingHeaders(r, clientIP)
|
||||
} else {
|
||||
p.setUntrustedForwardingHeaders(r, clientIP)
|
||||
}
|
||||
|
||||
stripSessionCookie(r)
|
||||
stripSessionTokenQuery(r)
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteLocationFunc returns a ModifyResponse function that rewrites Location
|
||||
// headers in backend responses when they point to the backend's address,
|
||||
// replacing them with the public-facing host and scheme.
|
||||
func (p *ReverseProxy) rewriteLocationFunc(target *url.URL, matchedPath string, inReq *http.Request) func(*http.Response) error {
|
||||
publicHost := inReq.Host
|
||||
publicScheme := auth.ResolveProto(p.forwardedProto, inReq.TLS)
|
||||
|
||||
return func(resp *http.Response) error {
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
locURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse Location header %q: %w", location, err)
|
||||
}
|
||||
|
||||
// Only rewrite absolute URLs that point to the backend.
|
||||
if locURL.Host == "" || !hostsEqual(locURL, target) {
|
||||
return nil
|
||||
}
|
||||
|
||||
locURL.Host = publicHost
|
||||
locURL.Scheme = publicScheme
|
||||
|
||||
// Re-add the stripped path prefix so the client reaches the correct route.
|
||||
// TrimRight prevents double slashes when matchedPath has a trailing slash.
|
||||
if matchedPath != "" && matchedPath != "/" {
|
||||
locURL.Path = strings.TrimRight(matchedPath, "/") + "/" + strings.TrimLeft(locURL.Path, "/")
|
||||
}
|
||||
|
||||
resp.Header.Set("Location", locURL.String())
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// hostsEqual compares two URL authorities, normalizing default ports per
|
||||
// RFC 3986 Section 6.2.3 (https://443 == https, http://80 == http).
|
||||
func hostsEqual(a, b *url.URL) bool {
|
||||
return normalizeHost(a) == normalizeHost(b)
|
||||
}
|
||||
|
||||
// normalizeHost strips the port from a URL's Host field if it matches the
|
||||
// scheme's default port (443 for https, 80 for http).
|
||||
func normalizeHost(u *url.URL) string {
|
||||
host, port, err := net.SplitHostPort(u.Host)
|
||||
if err != nil {
|
||||
return u.Host
|
||||
}
|
||||
if (u.Scheme == "https" && port == "443") || (u.Scheme == "http" && port == "80") {
|
||||
return host
|
||||
}
|
||||
return u.Host
|
||||
}
|
||||
|
||||
// setTrustedForwardingHeaders appends to the existing forwarding header chain
|
||||
// and preserves upstream-provided headers when the direct connection is from
|
||||
// a trusted proxy.
|
||||
func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
// Append the direct connection IP to the existing X-Forwarded-For chain.
|
||||
if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" {
|
||||
r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Real-IP if present; otherwise resolve through the chain.
|
||||
if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" {
|
||||
r.Out.Header.Set("X-Real-IP", realIP)
|
||||
} else {
|
||||
resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies)
|
||||
r.Out.Header.Set("X-Real-IP", resolved)
|
||||
}
|
||||
|
||||
// Preserve upstream X-Forwarded-Host if present.
|
||||
if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Host", fwdHost)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
}
|
||||
|
||||
// Trust upstream X-Forwarded-Proto; fall back to local resolution.
|
||||
if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", fwdProto)
|
||||
} else {
|
||||
r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS))
|
||||
}
|
||||
|
||||
// Trust upstream X-Forwarded-Port; fall back to local computation.
|
||||
if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" {
|
||||
r.Out.Header.Set("X-Forwarded-Port", fwdPort)
|
||||
} else {
|
||||
resolvedProto := r.Out.Header.Get("X-Forwarded-Proto")
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto))
|
||||
}
|
||||
}
|
||||
|
||||
// setUntrustedForwardingHeaders strips all incoming forwarding headers and
|
||||
// sets them fresh based on the direct connection. This is the default
|
||||
// behavior when no trusted proxies are configured or the direct connection
|
||||
// is from an untrusted source.
|
||||
func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) {
|
||||
proto := auth.ResolveProto(p.forwardedProto, r.In.TLS)
|
||||
r.Out.Header.Set("X-Forwarded-For", clientIP)
|
||||
r.Out.Header.Set("X-Real-IP", clientIP)
|
||||
r.Out.Header.Set("X-Forwarded-Host", r.In.Host)
|
||||
r.Out.Header.Set("X-Forwarded-Proto", proto)
|
||||
r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto))
|
||||
}
|
||||
|
||||
// stripSessionCookie removes the proxy's session cookie from the outgoing
|
||||
// request while preserving all other cookies.
|
||||
func stripSessionCookie(r *httputil.ProxyRequest) {
|
||||
cookies := r.In.Cookies()
|
||||
r.Out.Header.Del("Cookie")
|
||||
for _, c := range cookies {
|
||||
if c.Name != auth.SessionCookieName {
|
||||
r.Out.AddCookie(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// stripSessionTokenQuery removes the OIDC session_token query parameter from
|
||||
// the outgoing URL to prevent credential leakage to backends.
|
||||
func stripSessionTokenQuery(r *httputil.ProxyRequest) {
|
||||
q := r.Out.URL.Query()
|
||||
if q.Has("session_token") {
|
||||
q.Del("session_token")
|
||||
r.Out.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// extractClientIP extracts the IP address from an http.Request.RemoteAddr
|
||||
// which is always in host:port format.
|
||||
func extractClientIP(remoteAddr string) string {
|
||||
ip, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
return remoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// extractForwardedPort returns the port from the Host header if present,
|
||||
// otherwise defaults to the standard port for the resolved protocol.
|
||||
func extractForwardedPort(host, resolvedProto string) string {
|
||||
_, port, err := net.SplitHostPort(host)
|
||||
if err == nil && port != "" {
|
||||
return port
|
||||
}
|
||||
if resolvedProto == "https" {
|
||||
return "443"
|
||||
}
|
||||
return "80"
|
||||
}
|
||||
|
||||
// proxyErrorHandler handles errors from the reverse proxy and serves
|
||||
// user-friendly error pages instead of raw error responses.
|
||||
func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) {
|
||||
if cd := CapturedDataFromContext(r.Context()); cd != nil {
|
||||
cd.SetOrigin(OriginProxyError)
|
||||
}
|
||||
requestID := getRequestID(r)
|
||||
clientIP := getClientIP(r)
|
||||
title, message, code, status := classifyProxyError(err)
|
||||
|
||||
log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v",
|
||||
requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err)
|
||||
|
||||
web.ServeErrorPage(w, r, code, title, message, requestID, status)
|
||||
}
|
||||
|
||||
// getClientIP retrieves the resolved client IP from context.
|
||||
func getClientIP(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetClientIP()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// getRequestID retrieves the request ID from context or returns empty string.
|
||||
func getRequestID(r *http.Request) string {
|
||||
if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil {
|
||||
return capturedData.GetRequestID()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// classifyProxyError determines the appropriate error title, message, HTTP
|
||||
// status code, and component status based on the error type.
|
||||
func classifyProxyError(err error) (title, message string, code int, status web.ErrorStatus) {
|
||||
switch {
|
||||
case errors.Is(err, context.DeadlineExceeded),
|
||||
isNetTimeout(err):
|
||||
return "Request Timeout",
|
||||
"The request timed out while trying to reach the service. Please refresh the page and try again.",
|
||||
http.StatusGatewayTimeout,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case errors.Is(err, context.Canceled):
|
||||
return "Request Canceled",
|
||||
"The request was canceled before it could be completed. Please refresh the page and try again.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrNoAccountID):
|
||||
return "Configuration Error",
|
||||
"The request could not be processed due to a configuration issue. Please refresh the page and try again.",
|
||||
http.StatusInternalServerError,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrNoPeerConnection),
|
||||
errors.Is(err, roundtrip.ErrClientStartFailed):
|
||||
return "Proxy Not Connected",
|
||||
"The proxy is not connected to the NetBird network. Please try again later or contact your administrator.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: false, Destination: false}
|
||||
|
||||
case errors.Is(err, roundtrip.ErrTooManyInflight):
|
||||
return "Service Overloaded",
|
||||
"The service is currently handling too many requests. Please try again shortly.",
|
||||
http.StatusServiceUnavailable,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isConnectionRefused(err):
|
||||
return "Service Unavailable",
|
||||
"The connection to the service was refused. Please verify that the service is running and try again.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
|
||||
case isHostUnreachable(err):
|
||||
return "Peer Not Connected",
|
||||
"The connection to the peer could not be established. Please ensure the peer is running and connected to the NetBird network.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
return "Connection Error",
|
||||
"An unexpected error occurred while connecting to the service. Please try again later.",
|
||||
http.StatusBadGateway,
|
||||
web.ErrorStatus{Proxy: true, Destination: false}
|
||||
}
|
||||
|
||||
// isConnectionRefused checks for connection refused errors by inspecting
|
||||
// the inner error of a *net.OpError. This handles both standard net errors
|
||||
// (where the inner error is a *os.SyscallError with "connection refused")
|
||||
// and gVisor netstack errors ("connection was refused").
|
||||
func isConnectionRefused(err error) bool {
|
||||
return opErrorContains(err, "refused")
|
||||
}
|
||||
|
||||
// isHostUnreachable checks for host/network unreachable errors by inspecting
|
||||
// the inner error of a *net.OpError. Covers standard net ("no route to host",
|
||||
// "network is unreachable") and gVisor ("host is unreachable", etc.).
|
||||
func isHostUnreachable(err error) bool {
|
||||
return opErrorContains(err, "unreachable") || opErrorContains(err, "no route to host")
|
||||
}
|
||||
|
||||
// isNetTimeout checks whether the error is a network timeout using the
|
||||
// net.Error interface.
|
||||
func isNetTimeout(err error) bool {
|
||||
var netErr net.Error
|
||||
return errors.As(err, &netErr) && netErr.Timeout()
|
||||
}
|
||||
|
||||
// opErrorContains extracts the inner error from a *net.OpError and checks
|
||||
// whether its message contains the given substring. This handles gVisor
|
||||
// netstack errors which wrap tcpip errors as plain strings rather than
|
||||
// syscall.Errno values.
|
||||
func opErrorContains(err error, substr string) bool {
|
||||
var opErr *net.OpError
|
||||
if errors.As(err, &opErr) && opErr.Err != nil {
|
||||
return strings.Contains(opErr.Err.Error(), substr)
|
||||
}
|
||||
return false
|
||||
}
|
||||
966
proxy/internal/proxy/reverseproxy_test.go
Normal file
966
proxy/internal/proxy/reverseproxy_test.go
Normal file
@@ -0,0 +1,966 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"os"
|
||||
"syscall"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/auth"
|
||||
"github.com/netbirdio/netbird/proxy/internal/roundtrip"
|
||||
"github.com/netbirdio/netbird/proxy/web"
|
||||
)
|
||||
|
||||
func TestRewriteFunc_HostRewriting(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites host to backend by default", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.Host)
|
||||
})
|
||||
|
||||
t.Run("preserves original host when passHostHeader is true", func(t *testing.T) {
|
||||
rewrite := p.rewriteFunc(target, "", true)
|
||||
pr := newProxyRequest(t, "https://public.example.com/path", "203.0.113.1:12345")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "public.example.com", pr.Out.Host,
|
||||
"Host header should be the original client host")
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host,
|
||||
"URL host (used for TLS/SNI) must still point to the backend")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_XForwardedForStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("sets X-Forwarded-For from direct connection IP", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"should be set to the connecting client IP")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Forwarded-For from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"spoofed XFF must be replaced, not appended to")
|
||||
})
|
||||
|
||||
t.Run("strips spoofed X-Real-IP from client", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Real-IP", "10.0.0.1")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"spoofed X-Real-IP must be replaced")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_ForwardedHostAndProto(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
|
||||
t.Run("sets X-Forwarded-Host to original host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://myapp.example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "myapp.example.com:8443", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("sets X-Forwarded-Port from explicit host port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com:8443/path", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 443 for https", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("defaults X-Forwarded-Port to 80 for http", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("auto detects https from TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("auto detects http without TLS", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced proto overrides TLS detection", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
// No TLS, but forced to https
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("forced http proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "http"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "https://example.com/", "1.2.3.4:5000")
|
||||
pr.In.TLS = &tls.ConnectionState{}
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionCookieStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("strips nb_session cookie", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token-here"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
for _, c := range cookies {
|
||||
assert.NotEqual(t, auth.SessionCookieName, c.Name,
|
||||
"proxy session cookie must not be forwarded to backend")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("preserves other cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
pr.In.AddCookie(&http.Cookie{Name: auth.SessionCookieName, Value: "jwt-token"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "app_session", Value: "app-value"})
|
||||
pr.In.AddCookie(&http.Cookie{Name: "tracking", Value: "track-value"})
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
cookies := pr.Out.Cookies()
|
||||
cookieNames := make([]string, 0, len(cookies))
|
||||
for _, c := range cookies {
|
||||
cookieNames = append(cookieNames, c.Name)
|
||||
}
|
||||
assert.Contains(t, cookieNames, "app_session", "non-proxy cookies should be preserved")
|
||||
assert.Contains(t, cookieNames, "tracking", "non-proxy cookies should be preserved")
|
||||
assert.NotContains(t, cookieNames, auth.SessionCookieName, "proxy cookie must be stripped")
|
||||
})
|
||||
|
||||
t.Run("handles request with no cookies", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.Header.Get("Cookie"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_SessionTokenQueryStripping(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
t.Run("strips session_token query parameter", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/callback?session_token=secret123&other=keep", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Empty(t, pr.Out.URL.Query().Get("session_token"),
|
||||
"OIDC session token must be stripped from backend request")
|
||||
assert.Equal(t, "keep", pr.Out.URL.Query().Get("other"),
|
||||
"other query parameters must be preserved")
|
||||
})
|
||||
|
||||
t.Run("preserves query when no session_token present", func(t *testing.T) {
|
||||
pr := newProxyRequest(t, "http://example.com/api?foo=bar&baz=qux", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "bar", pr.Out.URL.Query().Get("foo"))
|
||||
assert.Equal(t, "qux", pr.Out.URL.Query().Get("baz"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteFunc_URLRewriting(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
t.Run("rewrites URL to target with path prefix", func(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080/app")
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
pr := newProxyRequest(t, "http://example.com/somepath", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "http", pr.Out.URL.Scheme)
|
||||
assert.Equal(t, "backend.internal:8080", pr.Out.URL.Host)
|
||||
assert.Equal(t, "/app/somepath", pr.Out.URL.Path,
|
||||
"SetURL should join the target base path with the request path")
|
||||
})
|
||||
|
||||
t.Run("strips matched path prefix to avoid duplication", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
pr := newProxyRequest(t, "http://example.com/app", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.URL.Scheme)
|
||||
assert.Equal(t, "backend.example.org:443", pr.Out.URL.Host)
|
||||
assert.Equal(t, "/app/", pr.Out.URL.Path,
|
||||
"matched path prefix should be stripped before joining with target path")
|
||||
})
|
||||
|
||||
t.Run("strips matched prefix and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.org:443/app")
|
||||
rewrite := p.rewriteFunc(target, "/app", false)
|
||||
pr := newProxyRequest(t, "http://example.com/app/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/app/article/123", pr.Out.URL.Path,
|
||||
"subpath after matched prefix should be preserved")
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractClientIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
expected string
|
||||
}{
|
||||
{"IPv4 with port", "192.168.1.1:12345", "192.168.1.1"},
|
||||
{"IPv6 with port", "[::1]:12345", "::1"},
|
||||
{"IPv6 full with port", "[2001:db8::1]:443", "2001:db8::1"},
|
||||
{"IPv4 without port fallback", "192.168.1.1", "192.168.1.1"},
|
||||
{"IPv6 without brackets fallback", "::1", "::1"},
|
||||
{"empty string fallback", "", ""},
|
||||
{"public IP", "203.0.113.50:9999", "203.0.113.50"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractClientIP(tt.remoteAddr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractForwardedPort(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
resolvedProto string
|
||||
expected string
|
||||
}{
|
||||
{"explicit port in host", "example.com:8443", "https", "8443"},
|
||||
{"explicit port overrides proto default", "example.com:9090", "http", "9090"},
|
||||
{"no port defaults to 443 for https", "example.com", "https", "443"},
|
||||
{"no port defaults to 80 for http", "example.com", "http", "80"},
|
||||
{"IPv6 host with port", "[::1]:8080", "http", "8080"},
|
||||
{"IPv6 host without port", "::1", "https", "443"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expected, extractForwardedPort(tt.host, tt.resolvedProto))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteFunc_TrustedProxy(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}
|
||||
|
||||
t.Run("appends to X-Forwarded-For", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50, 10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Real-IP", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
pr.In.Header.Set("X-Real-IP", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"))
|
||||
})
|
||||
|
||||
t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"should resolve real client through trusted chain")
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "original.example.com")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "original.example.com", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"))
|
||||
})
|
||||
|
||||
t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "8443")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port"))
|
||||
})
|
||||
|
||||
t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"),
|
||||
"should use configured forwardedProto as fallback")
|
||||
})
|
||||
|
||||
t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"))
|
||||
})
|
||||
|
||||
t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999")
|
||||
pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1")
|
||||
pr.In.Header.Set("X-Real-IP", "evil")
|
||||
pr.In.Header.Set("X-Forwarded-Host", "evil.example.com")
|
||||
pr.In.Header.Set("X-Forwarded-Proto", "https")
|
||||
pr.In.Header.Set("X-Forwarded-Port", "9999")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"untrusted: XFF must be replaced")
|
||||
assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"),
|
||||
"untrusted: X-Real-IP must be replaced")
|
||||
assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"),
|
||||
"untrusted: host must be from direct connection")
|
||||
assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"),
|
||||
"untrusted: proto must be locally resolved")
|
||||
assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"),
|
||||
"untrusted: port must be locally computed")
|
||||
})
|
||||
|
||||
t.Run("empty trusted list behaves as untrusted", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
pr.In.Header.Set("X-Forwarded-For", "203.0.113.50")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"nil trusted list: should strip and use RemoteAddr")
|
||||
})
|
||||
|
||||
t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted}
|
||||
rewrite := p.rewriteFunc(target, "", false)
|
||||
|
||||
pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"),
|
||||
"no upstream XFF: should set direct connection IP")
|
||||
})
|
||||
}
|
||||
|
||||
// TestRewriteFunc_PathForwarding verifies what path the backend actually
|
||||
// receives given different configurations. This simulates the full pipeline:
|
||||
// management builds a target URL (with matching prefix baked into the path),
|
||||
// then the proxy strips the prefix and SetURL re-joins with the target path.
|
||||
func TestRewriteFunc_PathForwarding(t *testing.T) {
|
||||
p := &ReverseProxy{forwardedProto: "auto"}
|
||||
|
||||
// Simulate what ToProtoMapping does: target URL includes the matching
|
||||
// prefix as its path component, so the proxy strips-then-re-adds.
|
||||
t.Run("path prefix baked into target URL is a no-op", func(t *testing.T) {
|
||||
// Management builds: path="/heise", target="https://heise.de:443/heise"
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise/", pr.Out.URL.Path,
|
||||
"backend sees /heise/ because prefix is stripped then re-added by SetURL")
|
||||
})
|
||||
|
||||
t.Run("subpath under prefix also preserved", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443/heise")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise/article/123", pr.Out.URL.Path,
|
||||
"subpath is preserved on top of the re-added prefix")
|
||||
})
|
||||
|
||||
// What the behavior WOULD be if target URL had no path (true stripping)
|
||||
t.Run("target without path prefix gives true stripping", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/", pr.Out.URL.Path,
|
||||
"without path in target URL, backend sees / (true prefix stripping)")
|
||||
})
|
||||
|
||||
t.Run("target without path prefix strips and preserves subpath", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://heise.de:443")
|
||||
rewrite := p.rewriteFunc(target, "/heise", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise/article/123", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/article/123", pr.Out.URL.Path,
|
||||
"without path in target URL, prefix is truly stripped")
|
||||
})
|
||||
|
||||
// Root path "/" — no stripping expected
|
||||
t.Run("root path forwards full request path unchanged", func(t *testing.T) {
|
||||
target, _ := url.Parse("https://backend.example.com:443/")
|
||||
rewrite := p.rewriteFunc(target, "/", false)
|
||||
pr := newProxyRequest(t, "http://external.test/heise", "1.2.3.4:5000")
|
||||
|
||||
rewrite(pr)
|
||||
|
||||
assert.Equal(t, "/heise", pr.Out.URL.Path,
|
||||
"root path match must not strip anything")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRewriteLocationFunc(t *testing.T) {
|
||||
target, _ := url.Parse("http://backend.internal:8080")
|
||||
newProxy := func(proto string) *ReverseProxy { return &ReverseProxy{forwardedProto: proto} }
|
||||
newReq := func(rawURL string) *http.Request {
|
||||
t.Helper()
|
||||
r := httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
parsed, _ := url.Parse(rawURL)
|
||||
r.Host = parsed.Host
|
||||
return r
|
||||
}
|
||||
run := func(p *ReverseProxy, matchedPath string, inReq *http.Request, location string) (*http.Response, error) {
|
||||
t.Helper()
|
||||
modifyResp := p.rewriteLocationFunc(target, matchedPath, inReq) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
if location != "" {
|
||||
resp.Header.Set("Location", location)
|
||||
}
|
||||
err := modifyResp(resp)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
t.Run("rewrites Location pointing to backend", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not rewrite Location pointing to other host", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"https://other.example.com/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://other.example.com/path", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not rewrite relative Location", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"/dashboard")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/dashboard", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("re-adds stripped path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/users"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/users")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("uses resolved proto for scheme", func(t *testing.T) {
|
||||
resp, err := run(newProxy("auto"), "", newReq("http://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("no-op when Location header is empty", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), "") //nolint:bodyclose
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("does not prepend root path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/", newReq("https://public.example.com/login"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: query parameters and fragments ---
|
||||
|
||||
t.Run("preserves query parameters", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/login?redirect=%2Fdashboard&lang=en")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/login?redirect=%2Fdashboard&lang=en", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves fragment", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/docs#section-2")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/docs#section-2", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves query parameters and fragment together", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/search?q=test&page=1#results")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/search?q=test&page=1#results", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves query parameters with path prefix re-added", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api", newReq("https://public.example.com/api/search"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/search?q=hello")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/search?q=hello", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: slash handling ---
|
||||
|
||||
t.Run("no double slash when matchedPath has trailing slash", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/api/", newReq("https://public.example.com/api/users"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/users")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/api/users", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to root with path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/app", newReq("https://public.example.com/app/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to root with trailing-slash path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/app/", newReq("https://public.example.com/app/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/app/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("preserves trailing slash on redirect path", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
t.Run("backend redirect to bare root", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/page"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/", resp.Header.Get("Location"))
|
||||
})
|
||||
|
||||
// --- Edge cases: host/port matching ---
|
||||
|
||||
t.Run("does not rewrite when backend host matches but port differs", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:9090/other")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://backend.internal:9090/other", resp.Header.Get("Location"),
|
||||
"Different port means different host authority, must not rewrite")
|
||||
})
|
||||
|
||||
t.Run("rewrites when redirect omits default port matching target", func(t *testing.T) {
|
||||
// Target is backend.internal:8080, redirect is to backend.internal (no port).
|
||||
// These are different authorities, so should NOT rewrite.
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal/path")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://backend.internal/path", resp.Header.Get("Location"),
|
||||
"backend.internal != backend.internal:8080, must not rewrite")
|
||||
})
|
||||
|
||||
t.Run("rewrites when target has :443 but redirect omits it for https", func(t *testing.T) {
|
||||
// Target: heise.de:443, redirect: https://heise.de/path (no :443 because it's default)
|
||||
// Per RFC 3986, these are the same authority.
|
||||
target443, _ := url.Parse("https://heise.de:443")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(target443, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://heise.de/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
|
||||
"heise.de:443 and heise.de are the same for https")
|
||||
})
|
||||
|
||||
t.Run("rewrites when target has :80 but redirect omits it for http", func(t *testing.T) {
|
||||
target80, _ := url.Parse("http://backend.local:80")
|
||||
p := newProxy("http")
|
||||
modifyResp := p.rewriteLocationFunc(target80, "", newReq("http://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "http://backend.local/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "http://public.example.com/path", resp.Header.Get("Location"),
|
||||
"backend.local:80 and backend.local are the same for http")
|
||||
})
|
||||
|
||||
t.Run("rewrites when redirect has :443 but target omits it", func(t *testing.T) {
|
||||
targetNoPort, _ := url.Parse("https://heise.de")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(targetNoPort, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://heise.de:443/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/path", resp.Header.Get("Location"),
|
||||
"heise.de and heise.de:443 are the same for https")
|
||||
})
|
||||
|
||||
t.Run("does not conflate non-default ports", func(t *testing.T) {
|
||||
target8443, _ := url.Parse("https://backend.internal:8443")
|
||||
p := newProxy("https")
|
||||
modifyResp := p.rewriteLocationFunc(target8443, "", newReq("https://public.example.com/")) //nolint:bodyclose
|
||||
resp := &http.Response{Header: http.Header{}}
|
||||
resp.Header.Set("Location", "https://backend.internal/path")
|
||||
|
||||
err := modifyResp(resp)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://backend.internal/path", resp.Header.Get("Location"),
|
||||
"backend.internal:8443 != backend.internal (port 443), must not rewrite")
|
||||
})
|
||||
|
||||
// --- Edge cases: encoded paths ---
|
||||
|
||||
t.Run("preserves percent-encoded path segments", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "", newReq("https://public.example.com/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/path%20with%20spaces/file%2Fname")
|
||||
|
||||
require.NoError(t, err)
|
||||
loc := resp.Header.Get("Location")
|
||||
assert.Contains(t, loc, "public.example.com")
|
||||
parsed, err := url.Parse(loc)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "/path with spaces/file/name", parsed.Path)
|
||||
})
|
||||
|
||||
t.Run("preserves encoded query parameters with path prefix", func(t *testing.T) {
|
||||
resp, err := run(newProxy("https"), "/v1", newReq("https://public.example.com/v1/"), //nolint:bodyclose
|
||||
"http://backend.internal:8080/redirect?url=http%3A%2F%2Fexample.com")
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "https://public.example.com/v1/redirect?url=http%3A%2F%2Fexample.com", resp.Header.Get("Location"))
|
||||
})
|
||||
}
|
||||
|
||||
// newProxyRequest creates an httputil.ProxyRequest suitable for testing
|
||||
// the Rewrite function. It simulates what httputil.ReverseProxy does internally:
|
||||
// Out is a shallow clone of In with headers copied.
|
||||
func newProxyRequest(t *testing.T, rawURL, remoteAddr string) *httputil.ProxyRequest {
|
||||
t.Helper()
|
||||
|
||||
parsed, err := url.Parse(rawURL)
|
||||
require.NoError(t, err)
|
||||
|
||||
in := httptest.NewRequest(http.MethodGet, rawURL, nil)
|
||||
in.RemoteAddr = remoteAddr
|
||||
in.Host = parsed.Host
|
||||
|
||||
out := in.Clone(in.Context())
|
||||
out.Header = in.Header.Clone()
|
||||
|
||||
return &httputil.ProxyRequest{In: in, Out: out}
|
||||
}
|
||||
|
||||
func TestClassifyProxyError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
wantTitle string
|
||||
wantCode int
|
||||
wantStatus web.ErrorStatus
|
||||
}{
|
||||
{
|
||||
name: "context deadline exceeded",
|
||||
err: context.DeadlineExceeded,
|
||||
wantTitle: "Request Timeout",
|
||||
wantCode: http.StatusGatewayTimeout,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "wrapped deadline exceeded",
|
||||
err: fmt.Errorf("dial: %w", context.DeadlineExceeded),
|
||||
wantTitle: "Request Timeout",
|
||||
wantCode: http.StatusGatewayTimeout,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "context canceled",
|
||||
err: context.Canceled,
|
||||
wantTitle: "Request Canceled",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "no account ID",
|
||||
err: roundtrip.ErrNoAccountID,
|
||||
wantTitle: "Configuration Error",
|
||||
wantCode: http.StatusInternalServerError,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "no peer connection",
|
||||
err: fmt.Errorf("%w for account: abc", roundtrip.ErrNoPeerConnection),
|
||||
wantTitle: "Proxy Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "client not started",
|
||||
err: fmt.Errorf("%w: %w", roundtrip.ErrClientStartFailed, errors.New("engine init failed")),
|
||||
wantTitle: "Proxy Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: false, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall ECONNREFUSED via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ECONNREFUSED},
|
||||
},
|
||||
wantTitle: "Service Unavailable",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor connection was refused",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("connection was refused"),
|
||||
},
|
||||
wantTitle: "Service Unavailable",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall EHOSTUNREACH via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "syscall ENETUNREACH via os.SyscallError",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.ENETUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor host is unreachable",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("host is unreachable"),
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "gvisor network is unreachable",
|
||||
err: &net.OpError{
|
||||
Op: "connect",
|
||||
Net: "tcp",
|
||||
Err: errors.New("network is unreachable"),
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "standard no route to host",
|
||||
err: &net.OpError{
|
||||
Op: "dial",
|
||||
Net: "tcp",
|
||||
Err: &os.SyscallError{Syscall: "connect", Err: syscall.EHOSTUNREACH},
|
||||
},
|
||||
wantTitle: "Peer Not Connected",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
{
|
||||
name: "unknown error falls to default",
|
||||
err: errors.New("something unexpected"),
|
||||
wantTitle: "Connection Error",
|
||||
wantCode: http.StatusBadGateway,
|
||||
wantStatus: web.ErrorStatus{Proxy: true, Destination: false},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
title, _, code, status := classifyProxyError(tt.err)
|
||||
assert.Equal(t, tt.wantTitle, title, "title")
|
||||
assert.Equal(t, tt.wantCode, code, "status code")
|
||||
assert.Equal(t, tt.wantStatus, status, "component status")
|
||||
})
|
||||
}
|
||||
}
|
||||
84
proxy/internal/proxy/servicemapping.go
Normal file
84
proxy/internal/proxy/servicemapping.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/netbirdio/netbird/proxy/internal/types"
|
||||
)
|
||||
|
||||
type Mapping struct {
|
||||
ID string
|
||||
AccountID types.AccountID
|
||||
Host string
|
||||
Paths map[string]*url.URL
|
||||
PassHostHeader bool
|
||||
RewriteRedirects bool
|
||||
}
|
||||
|
||||
type targetResult struct {
|
||||
url *url.URL
|
||||
matchedPath string
|
||||
serviceID string
|
||||
accountID types.AccountID
|
||||
passHostHeader bool
|
||||
rewriteRedirects bool
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) findTargetForRequest(req *http.Request) (targetResult, bool) {
|
||||
p.mappingsMux.RLock()
|
||||
defer p.mappingsMux.RUnlock()
|
||||
|
||||
// Strip port from host if present (e.g., "external.test:8443" -> "external.test")
|
||||
host := req.Host
|
||||
if h, _, err := net.SplitHostPort(host); err == nil {
|
||||
host = h
|
||||
}
|
||||
|
||||
m, exists := p.mappings[host]
|
||||
if !exists {
|
||||
p.logger.Debugf("no mapping found for host: %s", host)
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
// Sort paths by length (longest first) in a naive attempt to match the most specific route first.
|
||||
paths := make([]string, 0, len(m.Paths))
|
||||
for path := range m.Paths {
|
||||
paths = append(paths, path)
|
||||
}
|
||||
sort.Slice(paths, func(i, j int) bool {
|
||||
return len(paths[i]) > len(paths[j])
|
||||
})
|
||||
|
||||
for _, path := range paths {
|
||||
if strings.HasPrefix(req.URL.Path, path) {
|
||||
target := m.Paths[path]
|
||||
p.logger.Debugf("matched host: %s, path: %s -> %s", host, path, target)
|
||||
return targetResult{
|
||||
url: target,
|
||||
matchedPath: path,
|
||||
serviceID: m.ID,
|
||||
accountID: m.AccountID,
|
||||
passHostHeader: m.PassHostHeader,
|
||||
rewriteRedirects: m.RewriteRedirects,
|
||||
}, true
|
||||
}
|
||||
}
|
||||
p.logger.Debugf("no path match for host: %s, path: %s", host, req.URL.Path)
|
||||
return targetResult{}, false
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) AddMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
p.mappings[m.Host] = m
|
||||
}
|
||||
|
||||
func (p *ReverseProxy) RemoveMapping(m Mapping) {
|
||||
p.mappingsMux.Lock()
|
||||
defer p.mappingsMux.Unlock()
|
||||
delete(p.mappings, m.Host)
|
||||
}
|
||||
60
proxy/internal/proxy/trustedproxy.go
Normal file
60
proxy/internal/proxy/trustedproxy.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes.
|
||||
func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool {
|
||||
if len(trusted) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
addr, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, prefix := range trusted {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list.
|
||||
// It walks the XFF chain right-to-left, skipping IPs that match trusted prefixes.
|
||||
// The first untrusted IP is the real client.
|
||||
//
|
||||
// If the trusted list is empty or remoteAddr is not trusted, it returns the
|
||||
// remoteAddr IP directly (ignoring any forwarding headers).
|
||||
func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string {
|
||||
remoteIP := extractClientIP(remoteAddr)
|
||||
|
||||
if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
if xff == "" {
|
||||
return remoteIP
|
||||
}
|
||||
|
||||
parts := strings.Split(xff, ",")
|
||||
for i := len(parts) - 1; i >= 0; i-- {
|
||||
ip := strings.TrimSpace(parts[i])
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
if !IsTrustedProxy(ip, trusted) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
|
||||
// All IPs in XFF are trusted; return the leftmost as best guess.
|
||||
if first := strings.TrimSpace(parts[0]); first != "" {
|
||||
return first
|
||||
}
|
||||
return remoteIP
|
||||
}
|
||||
129
proxy/internal/proxy/trustedproxy_test.go
Normal file
129
proxy/internal/proxy/trustedproxy_test.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIsTrustedProxy(t *testing.T) {
|
||||
trusted := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("fd00::/8"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
trusted []netip.Prefix
|
||||
want bool
|
||||
}{
|
||||
{"empty trusted list", "10.0.0.1", nil, false},
|
||||
{"IP within /8 prefix", "10.1.2.3", trusted, true},
|
||||
{"IP within /24 prefix", "192.168.1.100", trusted, true},
|
||||
{"IP outside all prefixes", "203.0.113.50", trusted, false},
|
||||
{"boundary IP just outside prefix", "192.168.2.1", trusted, false},
|
||||
{"unparsable IP", "not-an-ip", trusted, false},
|
||||
{"IPv6 in trusted range", "fd00::1", trusted, true},
|
||||
{"IPv6 outside range", "2001:db8::1", trusted, false},
|
||||
{"empty string", "", trusted, false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveClientIP(t *testing.T) {
|
||||
trusted := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("172.16.0.0/12"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
remoteAddr string
|
||||
xff string
|
||||
trusted []netip.Prefix
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "empty trusted list returns RemoteAddr",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4",
|
||||
trusted: nil,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "untrusted RemoteAddr ignores XFF",
|
||||
remoteAddr: "203.0.113.50:9999",
|
||||
xff: "1.2.3.4, 10.0.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with single client in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr walks past trusted entries in XFF",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50, 10.0.0.2, 172.16.0.5",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.1",
|
||||
},
|
||||
{
|
||||
name: "all XFF IPs trusted returns leftmost",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "10.0.0.2, 172.16.0.1, 10.0.0.3",
|
||||
trusted: trusted,
|
||||
want: "10.0.0.2",
|
||||
},
|
||||
{
|
||||
name: "XFF with whitespace",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: " 203.0.113.50 , 10.0.0.2 ",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "XFF with empty segments",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "203.0.113.50,,10.0.0.2",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "multi-hop with mixed trust",
|
||||
remoteAddr: "10.0.0.1:5000",
|
||||
xff: "8.8.8.8, 203.0.113.50, 172.16.0.1",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
{
|
||||
name: "RemoteAddr without port",
|
||||
remoteAddr: "10.0.0.1",
|
||||
xff: "203.0.113.50",
|
||||
trusted: trusted,
|
||||
want: "203.0.113.50",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted))
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user