From ed58659a01976d54a5899ede3d0d4e8b4e6aa203 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 8 Feb 2026 17:31:10 +0800 Subject: [PATCH] Set forwarded headers from trusted proxies only --- proxy/cmd/proxy/cmd/root.go | 8 ++ proxy/internal/accesslog/logger.go | 17 ++- proxy/internal/accesslog/middleware.go | 7 +- proxy/internal/accesslog/requestip.go | 45 ++----- proxy/internal/proxy/context.go | 15 +++ proxy/internal/proxy/reverseproxy.go | 89 +++++++++++-- proxy/internal/proxy/reverseproxy_test.go | 153 ++++++++++++++++++++++ proxy/internal/proxy/trustedproxy.go | 60 +++++++++ proxy/internal/proxy/trustedproxy_test.go | 129 ++++++++++++++++++ proxy/server.go | 9 +- proxy/trustedproxy.go | 43 ++++++ proxy/trustedproxy_test.go | 90 +++++++++++++ 12 files changed, 608 insertions(+), 57 deletions(-) create mode 100644 proxy/internal/proxy/trustedproxy.go create mode 100644 proxy/internal/proxy/trustedproxy_test.go create mode 100644 proxy/trustedproxy.go create mode 100644 proxy/trustedproxy_test.go diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 8908fd58b..7c0cfb0e3 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -44,6 +44,7 @@ var ( oidcEndpoint string oidcScopes string forwardedProto string + trustedProxies string ) var rootCmd = &cobra.Command{ @@ -72,6 +73,7 @@ func init() { rootCmd.Flags().StringVar(&oidcEndpoint, "oidc-endpoint", envStringOrDefault("NB_PROXY_OIDC_ENDPOINT", ""), "The OIDC Endpoint for OIDC User Authentication") rootCmd.Flags().StringVar(&oidcScopes, "oidc-scopes", envStringOrDefault("NB_PROXY_OIDC_SCOPES", "openid,profile,email"), "The OAuth2 scopes for OIDC User Authentication, comma separated") rootCmd.Flags().StringVar(&forwardedProto, "forwarded-proto", envStringOrDefault("NB_PROXY_FORWARDED_PROTO", "auto"), "X-Forwarded-Proto value for backends: auto, http, or https") + rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", envStringOrDefault("NB_PROXY_TRUSTED_PROXIES", ""), "Comma-separated list of trusted upstream proxy CIDR ranges (e.g. '10.0.0.0/8,192.168.1.1')") } // Execute runs the root command. @@ -113,6 +115,11 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("invalid --forwarded-proto value %q: must be auto, http, or https", forwardedProto) } + parsedTrustedProxies, err := proxy.ParseTrustedProxies(trustedProxies) + if err != nil { + return fmt.Errorf("invalid --trusted-proxies: %w", err) + } + srv := proxy.Server{ Logger: logger, Version: Version, @@ -131,6 +138,7 @@ func runServer(cmd *cobra.Command, args []string) error { OIDCEndpoint: oidcEndpoint, OIDCScopes: strings.Split(oidcScopes, ","), ForwardedProto: forwardedProto, + TrustedProxies: parsedTrustedProxies, } if err := srv.ListenAndServe(context.TODO(), addr); err != nil { diff --git a/proxy/internal/accesslog/logger.go b/proxy/internal/accesslog/logger.go index b23f79b58..8640b831d 100644 --- a/proxy/internal/accesslog/logger.go +++ b/proxy/internal/accesslog/logger.go @@ -2,6 +2,7 @@ package accesslog import ( "context" + "net/netip" log "github.com/sirupsen/logrus" "google.golang.org/grpc" @@ -14,18 +15,24 @@ type gRPCClient interface { SendAccessLog(ctx context.Context, in *proto.SendAccessLogRequest, opts ...grpc.CallOption) (*proto.SendAccessLogResponse, error) } +// Logger sends access log entries to the management server via gRPC. type Logger struct { - client gRPCClient - logger *log.Logger + client gRPCClient + logger *log.Logger + trustedProxies []netip.Prefix } -func NewLogger(client gRPCClient, logger *log.Logger) *Logger { +// NewLogger creates a new access log Logger. The trustedProxies parameter +// configures which upstream proxy IP ranges are trusted for extracting +// the real client IP from X-Forwarded-For headers. +func NewLogger(client gRPCClient, logger *log.Logger, trustedProxies []netip.Prefix) *Logger { if logger == nil { logger = log.StandardLogger() } return &Logger{ - client: client, - logger: logger, + client: client, + logger: logger, + trustedProxies: trustedProxies, } } diff --git a/proxy/internal/accesslog/middleware.go b/proxy/internal/accesslog/middleware.go index 12828d50e..48d6b61b3 100644 --- a/proxy/internal/accesslog/middleware.go +++ b/proxy/internal/accesslog/middleware.go @@ -24,14 +24,15 @@ func (l *Logger) Middleware(next http.Handler) http.Handler { status: http.StatusOK, } - // Get the source IP before passing the request on as the proxy will modify - // headers that we wish to use to gather that information on the request. - sourceIp := extractSourceIP(r) + // Resolve the source IP using trusted proxy configuration before passing + // the request on, as the proxy will modify forwarding headers. + sourceIp := extractSourceIP(r, l.trustedProxies) // Create a mutable struct to capture data from downstream handlers. // We pass a pointer in the context - the pointer itself flows down immutably, // but the struct it points to can be mutated by inner handlers. capturedData := &proxy.CapturedData{RequestID: requestID} + capturedData.SetClientIP(sourceIp) ctx := proxy.WithCapturedData(r.Context(), capturedData) start := time.Now() diff --git a/proxy/internal/accesslog/requestip.go b/proxy/internal/accesslog/requestip.go index 9456a608f..f111c1322 100644 --- a/proxy/internal/accesslog/requestip.go +++ b/proxy/internal/accesslog/requestip.go @@ -1,43 +1,16 @@ package accesslog import ( - "net" "net/http" - "slices" - "strings" + "net/netip" + + "github.com/netbirdio/netbird/proxy/internal/proxy" ) -// requestIP attempts to extract the source IP from a request. -// Adapted from https://husobee.github.io/golang/ip-address/2015/12/17/remote-ip-go.html -// with the addition of some newer stdlib functions that are now -// available. -// The concept here is to look backwards through IP headers until -// the first public IP address is found. The hypothesis is that -// even if there are multiple IP addresses specified in these headers, -// the last public IP should be the hop immediately before reaching -// the server and therefore represents the "true" source IP regardless -// of the number of intermediate proxies or network hops. -func extractSourceIP(r *http.Request) string { - for _, h := range []string{"X-Forwarded-For", "X-Real-IP"} { - addresses := strings.Split(r.Header.Get(h), ",") - // Iterate from right to left until we get a public address - // that should be the address right before our proxy. - for _, address := range slices.Backward(addresses) { - // Trim the address because sometimes clients put whitespace in there. - ip := strings.TrimSpace(address) - // Parse the IP so that we can easily check whether it is a valid public address. - realIP := net.ParseIP(ip) - if !realIP.IsGlobalUnicast() || realIP.IsPrivate() || realIP.IsLoopback() { - continue - } - return ip - } - } - // Fallback to the requests RemoteAddr, this is least likely to be correct but - // should at least yield something in the event that the above has failed. - ip, _, err := net.SplitHostPort(r.RemoteAddr) - if err != nil { - ip = r.RemoteAddr - } - return ip +// extractSourceIP resolves the real client IP from the request using trusted +// proxy configuration. When trustedProxies is non-empty and the direct +// connection is from a trusted source, it walks X-Forwarded-For right-to-left +// skipping trusted IPs. Otherwise it returns RemoteAddr directly. +func extractSourceIP(r *http.Request, trustedProxies []netip.Prefix) string { + return proxy.ResolveClientIP(r.RemoteAddr, r.Header.Get("X-Forwarded-For"), trustedProxies) } diff --git a/proxy/internal/proxy/context.go b/proxy/internal/proxy/context.go index a60d14b0b..460f04ed0 100644 --- a/proxy/internal/proxy/context.go +++ b/proxy/internal/proxy/context.go @@ -50,6 +50,7 @@ type CapturedData struct { ServiceId string AccountId types.AccountID Origin ResponseOrigin + ClientIP string } // GetRequestID safely gets the request ID @@ -101,6 +102,20 @@ func (c *CapturedData) GetOrigin() ResponseOrigin { return c.Origin } +// SetClientIP safely sets the resolved client IP. +func (c *CapturedData) SetClientIP(ip string) { + c.mu.Lock() + defer c.mu.Unlock() + c.ClientIP = ip +} + +// GetClientIP safely gets the resolved client IP. +func (c *CapturedData) GetClientIP() string { + c.mu.RLock() + defer c.mu.RUnlock() + return c.ClientIP +} + // WithCapturedData adds a CapturedData struct to the context func WithCapturedData(ctx context.Context, data *CapturedData) context.Context { return context.WithValue(ctx, capturedDataKey, data) diff --git a/proxy/internal/proxy/reverseproxy.go b/proxy/internal/proxy/reverseproxy.go index 99954fe4c..f348dd389 100644 --- a/proxy/internal/proxy/reverseproxy.go +++ b/proxy/internal/proxy/reverseproxy.go @@ -6,6 +6,7 @@ import ( "net" "net/http" "net/http/httputil" + "net/netip" "net/url" "strings" "sync" @@ -22,6 +23,10 @@ type ReverseProxy struct { // forwardedProto overrides the X-Forwarded-Proto header value. // Valid values: "auto" (detect from TLS), "http", "https". forwardedProto string + // trustedProxies is a list of IP prefixes for trusted upstream proxies. + // When the direct connection comes from a trusted proxy, forwarding + // headers are preserved and appended to instead of being stripped. + trustedProxies []netip.Prefix mappingsMux sync.RWMutex mappings map[string]Mapping logger *log.Logger @@ -33,13 +38,14 @@ type ReverseProxy struct { // between requested URLs and targets. // The internal mappings can be modified using the AddMapping // and RemoveMapping functions. -func NewReverseProxy(transport http.RoundTripper, forwardedProto string, logger *log.Logger) *ReverseProxy { +func NewReverseProxy(transport http.RoundTripper, forwardedProto string, trustedProxies []netip.Prefix, logger *log.Logger) *ReverseProxy { if logger == nil { logger = log.StandardLogger() } return &ReverseProxy{ transport: transport, forwardedProto: forwardedProto, + trustedProxies: trustedProxies, mappings: make(map[string]Mapping), logger: logger, } @@ -96,21 +102,73 @@ func (p *ReverseProxy) rewriteFunc(target *url.URL, passHostHeader bool) func(r } clientIP := extractClientIP(r.In.RemoteAddr) - proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) - // Strip any incoming forwarding headers since this proxy is the trust - // boundary and set them fresh based on the direct connection. - r.Out.Header.Set("X-Forwarded-For", clientIP) - r.Out.Header.Set("X-Real-IP", clientIP) - r.Out.Header.Set("X-Forwarded-Host", r.In.Host) - r.Out.Header.Set("X-Forwarded-Proto", proto) - r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) + if IsTrustedProxy(clientIP, p.trustedProxies) { + p.setTrustedForwardingHeaders(r, clientIP) + } else { + p.setUntrustedForwardingHeaders(r, clientIP) + } stripSessionCookie(r) stripSessionTokenQuery(r) } } +// setTrustedForwardingHeaders appends to the existing forwarding header chain +// and preserves upstream-provided headers when the direct connection is from +// a trusted proxy. +func (p *ReverseProxy) setTrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { + // Append the direct connection IP to the existing X-Forwarded-For chain. + if existing := r.In.Header.Get("X-Forwarded-For"); existing != "" { + r.Out.Header.Set("X-Forwarded-For", existing+", "+clientIP) + } else { + r.Out.Header.Set("X-Forwarded-For", clientIP) + } + + // Preserve upstream X-Real-IP if present; otherwise resolve through the chain. + if realIP := r.In.Header.Get("X-Real-IP"); realIP != "" { + r.Out.Header.Set("X-Real-IP", realIP) + } else { + resolved := ResolveClientIP(r.In.RemoteAddr, r.In.Header.Get("X-Forwarded-For"), p.trustedProxies) + r.Out.Header.Set("X-Real-IP", resolved) + } + + // Preserve upstream X-Forwarded-Host if present. + if fwdHost := r.In.Header.Get("X-Forwarded-Host"); fwdHost != "" { + r.Out.Header.Set("X-Forwarded-Host", fwdHost) + } else { + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + } + + // Trust upstream X-Forwarded-Proto; fall back to local resolution. + if fwdProto := r.In.Header.Get("X-Forwarded-Proto"); fwdProto != "" { + r.Out.Header.Set("X-Forwarded-Proto", fwdProto) + } else { + r.Out.Header.Set("X-Forwarded-Proto", auth.ResolveProto(p.forwardedProto, r.In.TLS)) + } + + // Trust upstream X-Forwarded-Port; fall back to local computation. + if fwdPort := r.In.Header.Get("X-Forwarded-Port"); fwdPort != "" { + r.Out.Header.Set("X-Forwarded-Port", fwdPort) + } else { + resolvedProto := r.Out.Header.Get("X-Forwarded-Proto") + r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, resolvedProto)) + } +} + +// setUntrustedForwardingHeaders strips all incoming forwarding headers and +// sets them fresh based on the direct connection. This is the default +// behavior when no trusted proxies are configured or the direct connection +// is from an untrusted source. +func (p *ReverseProxy) setUntrustedForwardingHeaders(r *httputil.ProxyRequest, clientIP string) { + proto := auth.ResolveProto(p.forwardedProto, r.In.TLS) + r.Out.Header.Set("X-Forwarded-For", clientIP) + r.Out.Header.Set("X-Real-IP", clientIP) + r.Out.Header.Set("X-Forwarded-Host", r.In.Host) + r.Out.Header.Set("X-Forwarded-Proto", proto) + r.Out.Header.Set("X-Forwarded-Port", extractForwardedPort(r.In.Host, proto)) +} + // stripSessionCookie removes the proxy's session cookie from the outgoing // request while preserving all other cookies. func stripSessionCookie(r *httputil.ProxyRequest) { @@ -163,14 +221,23 @@ func proxyErrorHandler(w http.ResponseWriter, r *http.Request, err error) { cd.SetOrigin(OriginProxyError) } requestID := getRequestID(r) + clientIP := getClientIP(r) title, message, code, status := classifyProxyError(err) - log.Warnf("proxy error: request_id=%s method=%s host=%s path=%s status=%d title=%q err=%v", - requestID, r.Method, r.Host, r.URL.Path, code, title, err) + log.Warnf("proxy error: request_id=%s client_ip=%s method=%s host=%s path=%s status=%d title=%q err=%v", + requestID, clientIP, r.Method, r.Host, r.URL.Path, code, title, err) web.ServeErrorPage(w, r, code, title, message, requestID, status) } +// getClientIP retrieves the resolved client IP from context. +func getClientIP(r *http.Request) string { + if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { + return capturedData.GetClientIP() + } + return "" +} + // getRequestID retrieves the request ID from context or returns empty string. func getRequestID(r *http.Request) string { if capturedData := CapturedDataFromContext(r.Context()); capturedData != nil { diff --git a/proxy/internal/proxy/reverseproxy_test.go b/proxy/internal/proxy/reverseproxy_test.go index a8038bc1d..8a2ad209c 100644 --- a/proxy/internal/proxy/reverseproxy_test.go +++ b/proxy/internal/proxy/reverseproxy_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" + "net/netip" "net/url" "testing" @@ -293,6 +294,158 @@ func TestExtractForwardedPort(t *testing.T) { } } +func TestRewriteFunc_TrustedProxy(t *testing.T) { + target, _ := url.Parse("http://backend.internal:8080") + trusted := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")} + + t.Run("appends to X-Forwarded-For", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") + + rewrite(pr) + + assert.Equal(t, "203.0.113.50, 10.0.0.1", pr.Out.Header.Get("X-Forwarded-For")) + }) + + t.Run("preserves upstream X-Real-IP", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") + pr.In.Header.Set("X-Real-IP", "203.0.113.50") + + rewrite(pr) + + assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP")) + }) + + t.Run("resolves X-Real-IP from XFF when not set by upstream", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-For", "203.0.113.50, 10.0.0.2") + + rewrite(pr) + + assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"), + "should resolve real client through trusted chain") + }) + + t.Run("preserves upstream X-Forwarded-Host", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://proxy.internal/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-Host", "original.example.com") + + rewrite(pr) + + assert.Equal(t, "original.example.com", pr.Out.Header.Get("X-Forwarded-Host")) + }) + + t.Run("preserves upstream X-Forwarded-Proto", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-Proto", "https") + + rewrite(pr) + + assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto")) + }) + + t.Run("preserves upstream X-Forwarded-Port", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-Port", "8443") + + rewrite(pr) + + assert.Equal(t, "8443", pr.Out.Header.Get("X-Forwarded-Port")) + }) + + t.Run("falls back to local proto when upstream does not set it", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "https", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + + rewrite(pr) + + assert.Equal(t, "https", pr.Out.Header.Get("X-Forwarded-Proto"), + "should use configured forwardedProto as fallback") + }) + + t.Run("sets X-Forwarded-Host from request when upstream does not set it", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + + rewrite(pr) + + assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host")) + }) + + t.Run("untrusted RemoteAddr strips headers even with trusted list", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "203.0.113.50:9999") + pr.In.Header.Set("X-Forwarded-For", "10.0.0.1, 172.16.0.1") + pr.In.Header.Set("X-Real-IP", "evil") + pr.In.Header.Set("X-Forwarded-Host", "evil.example.com") + pr.In.Header.Set("X-Forwarded-Proto", "https") + pr.In.Header.Set("X-Forwarded-Port", "9999") + + rewrite(pr) + + assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Forwarded-For"), + "untrusted: XFF must be replaced") + assert.Equal(t, "203.0.113.50", pr.Out.Header.Get("X-Real-IP"), + "untrusted: X-Real-IP must be replaced") + assert.Equal(t, "example.com", pr.Out.Header.Get("X-Forwarded-Host"), + "untrusted: host must be from direct connection") + assert.Equal(t, "http", pr.Out.Header.Get("X-Forwarded-Proto"), + "untrusted: proto must be locally resolved") + assert.Equal(t, "80", pr.Out.Header.Get("X-Forwarded-Port"), + "untrusted: port must be locally computed") + }) + + t.Run("empty trusted list behaves as untrusted", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: nil} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + pr.In.Header.Set("X-Forwarded-For", "203.0.113.50") + + rewrite(pr) + + assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"), + "nil trusted list: should strip and use RemoteAddr") + }) + + t.Run("XFF starts fresh when trusted proxy has no upstream XFF", func(t *testing.T) { + p := &ReverseProxy{forwardedProto: "auto", trustedProxies: trusted} + rewrite := p.rewriteFunc(target, false) + + pr := newProxyRequest(t, "http://example.com/", "10.0.0.1:5000") + + rewrite(pr) + + assert.Equal(t, "10.0.0.1", pr.Out.Header.Get("X-Forwarded-For"), + "no upstream XFF: should set direct connection IP") + }) +} + // newProxyRequest creates an httputil.ProxyRequest suitable for testing // the Rewrite function. It simulates what httputil.ReverseProxy does internally: // Out is a shallow clone of In with headers copied. diff --git a/proxy/internal/proxy/trustedproxy.go b/proxy/internal/proxy/trustedproxy.go new file mode 100644 index 000000000..ad9a5b6c0 --- /dev/null +++ b/proxy/internal/proxy/trustedproxy.go @@ -0,0 +1,60 @@ +package proxy + +import ( + "net/netip" + "strings" +) + +// IsTrustedProxy checks if the given IP string falls within any of the trusted prefixes. +func IsTrustedProxy(ipStr string, trusted []netip.Prefix) bool { + if len(trusted) == 0 { + return false + } + + addr, err := netip.ParseAddr(ipStr) + if err != nil { + return false + } + + for _, prefix := range trusted { + if prefix.Contains(addr) { + return true + } + } + return false +} + +// ResolveClientIP extracts the real client IP from X-Forwarded-For using the trusted proxy list. +// It walks the XFF chain right-to-left, skipping IPs that match trusted prefixes. +// The first untrusted IP is the real client. +// +// If the trusted list is empty or remoteAddr is not trusted, it returns the +// remoteAddr IP directly (ignoring any forwarding headers). +func ResolveClientIP(remoteAddr, xff string, trusted []netip.Prefix) string { + remoteIP := extractClientIP(remoteAddr) + + if len(trusted) == 0 || !IsTrustedProxy(remoteIP, trusted) { + return remoteIP + } + + if xff == "" { + return remoteIP + } + + parts := strings.Split(xff, ",") + for i := len(parts) - 1; i >= 0; i-- { + ip := strings.TrimSpace(parts[i]) + if ip == "" { + continue + } + if !IsTrustedProxy(ip, trusted) { + return ip + } + } + + // All IPs in XFF are trusted; return the leftmost as best guess. + if first := strings.TrimSpace(parts[0]); first != "" { + return first + } + return remoteIP +} diff --git a/proxy/internal/proxy/trustedproxy_test.go b/proxy/internal/proxy/trustedproxy_test.go new file mode 100644 index 000000000..4c8eb15a6 --- /dev/null +++ b/proxy/internal/proxy/trustedproxy_test.go @@ -0,0 +1,129 @@ +package proxy + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsTrustedProxy(t *testing.T) { + trusted := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("fd00::/8"), + } + + tests := []struct { + name string + ip string + trusted []netip.Prefix + want bool + }{ + {"empty trusted list", "10.0.0.1", nil, false}, + {"IP within /8 prefix", "10.1.2.3", trusted, true}, + {"IP within /24 prefix", "192.168.1.100", trusted, true}, + {"IP outside all prefixes", "203.0.113.50", trusted, false}, + {"boundary IP just outside prefix", "192.168.2.1", trusted, false}, + {"unparseable IP", "not-an-ip", trusted, false}, + {"IPv6 in trusted range", "fd00::1", trusted, true}, + {"IPv6 outside range", "2001:db8::1", trusted, false}, + {"empty string", "", trusted, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, IsTrustedProxy(tt.ip, tt.trusted)) + }) + } +} + +func TestResolveClientIP(t *testing.T) { + trusted := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("172.16.0.0/12"), + } + + tests := []struct { + name string + remoteAddr string + xff string + trusted []netip.Prefix + want string + }{ + { + name: "empty trusted list returns RemoteAddr", + remoteAddr: "203.0.113.50:9999", + xff: "1.2.3.4", + trusted: nil, + want: "203.0.113.50", + }, + { + name: "untrusted RemoteAddr ignores XFF", + remoteAddr: "203.0.113.50:9999", + xff: "1.2.3.4, 10.0.0.1", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "trusted RemoteAddr with single client in XFF", + remoteAddr: "10.0.0.1:5000", + xff: "203.0.113.50", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "trusted RemoteAddr walks past trusted entries in XFF", + remoteAddr: "10.0.0.1:5000", + xff: "203.0.113.50, 10.0.0.2, 172.16.0.5", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "trusted RemoteAddr with empty XFF falls back to RemoteAddr", + remoteAddr: "10.0.0.1:5000", + xff: "", + trusted: trusted, + want: "10.0.0.1", + }, + { + name: "all XFF IPs trusted returns leftmost", + remoteAddr: "10.0.0.1:5000", + xff: "10.0.0.2, 172.16.0.1, 10.0.0.3", + trusted: trusted, + want: "10.0.0.2", + }, + { + name: "XFF with whitespace", + remoteAddr: "10.0.0.1:5000", + xff: " 203.0.113.50 , 10.0.0.2 ", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "XFF with empty segments", + remoteAddr: "10.0.0.1:5000", + xff: "203.0.113.50,,10.0.0.2", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "multi-hop with mixed trust", + remoteAddr: "10.0.0.1:5000", + xff: "8.8.8.8, 203.0.113.50, 172.16.0.1", + trusted: trusted, + want: "203.0.113.50", + }, + { + name: "RemoteAddr without port", + remoteAddr: "10.0.0.1", + xff: "203.0.113.50", + trusted: trusted, + want: "203.0.113.50", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, ResolveClientIP(tt.remoteAddr, tt.xff, tt.trusted)) + }) + } +} diff --git a/proxy/server.go b/proxy/server.go index 582bfe6eb..c3c6ca218 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -16,6 +16,7 @@ import ( "io" "net" "net/http" + "net/netip" "net/url" "path/filepath" "time" @@ -82,6 +83,10 @@ type Server struct { // ForwardedProto overrides the X-Forwarded-Proto value sent to backends. // Valid values: "auto" (detect from TLS), "http", "https". ForwardedProto string + // TrustedProxies is a list of IP prefixes for trusted upstream proxies. + // When set, forwarding headers from these sources are preserved and + // appended to instead of being stripped. + TrustedProxies []netip.Prefix } // NotifyStatus sends a status update to management about tunnel connectivity @@ -217,13 +222,13 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) { } // Configure the reverse proxy using NetBird's HTTP Client Transport for proxying. - s.proxy = proxy.NewReverseProxy(s.netbird, s.ForwardedProto, s.Logger) + s.proxy = proxy.NewReverseProxy(s.netbird, s.ForwardedProto, s.TrustedProxies, s.Logger) // Configure the authentication middleware. s.auth = auth.NewMiddleware(s.Logger) // Configure Access logs to management server. - accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger) + accessLog := accesslog.NewLogger(s.mgmtClient, s.Logger, s.TrustedProxies) if s.DebugEndpointEnabled { debugAddr := debugEndpointAddr(s.DebugEndpointAddress) diff --git a/proxy/trustedproxy.go b/proxy/trustedproxy.go new file mode 100644 index 000000000..3a1f0ad37 --- /dev/null +++ b/proxy/trustedproxy.go @@ -0,0 +1,43 @@ +package proxy + +import ( + "fmt" + "net/netip" + "strings" +) + +// ParseTrustedProxies parses a comma-separated list of CIDR prefixes or bare IPs +// into a slice of netip.Prefix values suitable for trusted proxy configuration. +// Bare IPs are converted to single-host prefixes (/32 or /128). +func ParseTrustedProxies(raw string) ([]netip.Prefix, error) { + if raw == "" { + return nil, nil + } + + parts := strings.Split(raw, ",") + prefixes := make([]netip.Prefix, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + prefix, err := netip.ParsePrefix(part) + if err == nil { + prefixes = append(prefixes, prefix) + continue + } + + addr, addrErr := netip.ParseAddr(part) + if addrErr != nil { + return nil, fmt.Errorf("parse trusted proxy %q: not a valid CIDR or IP: %w", part, addrErr) + } + + bits := 32 + if addr.Is6() { + bits = 128 + } + prefixes = append(prefixes, netip.PrefixFrom(addr, bits)) + } + return prefixes, nil +} diff --git a/proxy/trustedproxy_test.go b/proxy/trustedproxy_test.go new file mode 100644 index 000000000..974e56863 --- /dev/null +++ b/proxy/trustedproxy_test.go @@ -0,0 +1,90 @@ +package proxy + +import ( + "net/netip" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseTrustedProxies(t *testing.T) { + tests := []struct { + name string + raw string + want []netip.Prefix + wantErr bool + }{ + { + name: "empty string returns nil", + raw: "", + want: nil, + }, + { + name: "single CIDR", + raw: "10.0.0.0/8", + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "single bare IPv4", + raw: "1.2.3.4", + want: []netip.Prefix{netip.MustParsePrefix("1.2.3.4/32")}, + }, + { + name: "single bare IPv6", + raw: "::1", + want: []netip.Prefix{netip.MustParsePrefix("::1/128")}, + }, + { + name: "comma-separated CIDRs", + raw: "10.0.0.0/8, 192.168.1.0/24", + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "mixed CIDRs and bare IPs", + raw: "10.0.0.0/8, 1.2.3.4, fd00::/8", + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("1.2.3.4/32"), + netip.MustParsePrefix("fd00::/8"), + }, + }, + { + name: "whitespace around entries", + raw: " 10.0.0.0/8 , 192.168.0.0/16 ", + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + netip.MustParsePrefix("192.168.0.0/16"), + }, + }, + { + name: "trailing comma produces no extra entry", + raw: "10.0.0.0/8,", + want: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "invalid entry", + raw: "not-an-ip", + wantErr: true, + }, + { + name: "partially invalid", + raw: "10.0.0.0/8, garbage", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTrustedProxies(tt.raw) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +}